vignettes/partial-evaluation.Rmd
One approach to runtime optimisation is partial evaluation. It isn’t a complicated idea: if you are computing the same thing again and again, you are better off computing it once and remembering the result. The idea is used in many places, both in implementations and in the design of algorithms, where it goes by names such as dynamic programming or memoisation. When it comes to functions, though, it involves evaluating the parts of a function that we can with the parameters we are able to fix, and keeping the rest of the function as code that must be evaluated later.
Consider a simple example where we have a function that scales one vector, y, using the mean and standard deviation of another vector, x:
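A minimal definition of such a function, consistent with the transformed output printed later in this document (the exact original formulation is an assumption):

```r
# Scale `y` by the mean and standard deviation of `x`.
rescale <- function(x, y) {
  (y - mean(x)) / sd(x)
}

rescale(c(1, 2, 3), c(2, 4))
```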
The function takes both vectors as input, but we can curry it to get a function of one argument, x, that returns another function that also takes one argument, y, and then computes the same expression.
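A sketch of what the curried version could look like. The name rescale_curry matches the one used in the benchmark further down, but the body is an assumption:

```r
# Curried form: take `x` first and return a function of `y`.
# Nothing is precomputed; mean(x) and sd(x) are evaluated on every call.
rescale_curry <- function(x) {
  function(y) (y - mean(x)) / sd(x)
}

scale_by_x <- rescale_curry(c(1, 2, 3))
scale_by_x(c(2, 4))
```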
Currying alone can seem a little pointless, but notice that mean(x) and sd(x) depend only on x. If we need to apply the function to many different y vectors with the same x, we can rewrite it so that these two quantities are computed as soon as x is known and then reused for every value of y.
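A hand-written, partially evaluated version might look like the following sketch. The name rescale_curry_pe is taken from the benchmark below; the body and the helper names mx and sx are assumptions:

```r
# Curried and partially evaluated by hand: the two summaries of `x` are
# computed once, when `x` is supplied, and captured in the closure.
rescale_curry_pe <- function(x) {
  mx <- mean(x)
  sx <- sd(x)
  function(y) (y - mx) / sx
}

scale_by_x <- rescale_curry_pe(c(1, 2, 3))
scale_by_x(c(2, 4))
```

The returned closure only does a subtraction and a division, which is where any speed-up comes from when it is called many times.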
Currying alone doesn’t change the performance, but combining it with partial evaluation (computing what we can once we know x and before we know y) can have a substantial effect.
time_uncurry <- function(f) {
  n <- 500
  x <- rnorm(1000, 0, 1)
  y <- rnorm(1000, 10, 100)
  replicate(n, f(x, y))
}
time_curry <- function(f) {
  n <- 500
  x <- rnorm(1000, 0, 1)
  y <- rnorm(1000, 10, 100)
  g <- f(x)
  replicate(n, g(y))
}
microbenchmark::microbenchmark(
  time_uncurry(rescale),
  time_curry(rescale_curry),
  time_curry(rescale_curry_pe)
)
#> Unit: milliseconds
#>                          expr       min        lq     mean    median       uq       max neval
#>         time_uncurry(rescale) 15.046865 15.642861 21.74142 16.380228 19.74924  44.47971   100
#>     time_curry(rescale_curry) 14.997960 15.410322 20.60673 15.989984 17.13966 110.54133   100
#>  time_curry(rescale_curry_pe)  3.367868  4.878541 11.79211  5.703464 26.65673  29.77731   100
The actual performance difference depends, of course, on how many times we call the partially evaluated function and how complex it is.
We can use foolbox to partially evaluate a function automatically by replacing a variable with a value. In its simplest form, we just substitute a value for a variable:
library(foolbox)
subst_var_callback <- function(var, val) {
  function(expr, ...) {
    if (expr == var) rlang::expr(!!val)
    else expr
  }
}
pe <- function(fn, var, val) {
  var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
  force(val)
  fn %>% rewrite() %>% rewrite_with(
    rewrite_callbacks() %>%
      with_symbol_callback(subst_var_callback(var, val))
  ) %>% remove_formal_(var)
}
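For comparison, the same kind of substitution can be sketched in base R alone. The helper subst_in_body below is hypothetical and not part of foolbox; it relies on substitute(), which knows nothing about scope and will also rewrite the variable inside nested functions:

```r
# Hypothetical helper (not part of foolbox): replace `var_name` in the body
# of `fn` with the value `val` and drop it from the formal parameters.
subst_in_body <- function(fn, var_name, val) {
  bindings <- stats::setNames(list(val), var_name)
  # substitute() does not evaluate its first argument, so we build the call
  # with do.call() to pass the body expression itself
  body(fn) <- do.call(substitute, list(body(fn), bindings))
  formals(fn) <- formals(fn)[setdiff(names(formals(fn)), var_name)]
  fn
}

scale_fn <- function(x, y) (y - mean(x)) / sd(x)
scale_fn_x <- subst_in_body(scale_fn, "x", c(1, 2, 3))
scale_fn_x(c(2, 4))
```

As with the foolbox version, this only inserts the value; mean and sd are still called every time the new function runs.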
We can see pe in action by trying it out on the rescale function:
test_x <- rnorm(3, 0, 1) %>% round(digits = 2)
rescale_pe <- rescale %>% pe(x, test_x)
rescale_pe
#> function (y)
#> {
#> (y - mean(c(0.82, 0.17, 0.39)))/sd(c(0.82, 0.17, 0.39))
#> }
test_y <- rnorm(5, 10, 100)
rescale(test_x, test_y)
#> [1] 207.4550 376.3745 -762.1280 -334.4448 164.2914
rescale_pe(test_y)
#> [1] 207.4550 376.3745 -762.1280 -334.4448 164.2914
We have inserted a value for x, but that is all. We have also, as it turns out, managed to make our program slower: even though we run the partially evaluated function 500 times, the runtime cost of transforming it outweighs whatever gain we get from the transformation.
microbenchmark::microbenchmark(
  time_uncurry(rescale),
  time_curry(function(x) pe(rescale, x, x))
)
#> Unit: milliseconds
#>                                       expr      min      lq     mean   median       uq      max neval
#>                      time_uncurry(rescale) 14.15367 15.3082 18.67470 16.02201 16.70747 46.44439   100
#>  time_curry(function(x) pe(rescale, x, x)) 20.40191 21.9884 25.97323 22.61048 23.66992 52.10957   100
This is perhaps not so surprising, considering that we never actually evaluate the mean(x) and sd(x) expressions we wanted to evaluate.
We can fix this by trying to evaluate all calls if all their arguments are evaluated, i.e. if all their arguments are atomic.
eval_attempt <- function(expr, env) {
  tryCatch(rlang::expr(!!eval(expr, env)),
           error = function(e) expr)
}
call_callback <- function(expr, env, ...) {
  call_args_atomic <- rlang::call_args(expr) %>%
    Map(rlang::is_atomic, .) %>% unlist
  if (any(!call_args_atomic)) expr
  else eval_attempt(expr, env)
}
pe <- function(fn, var, val) {
  var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
  force(val)
  fn %>% rewrite() %>% rewrite_with(
    rewrite_callbacks() %>%
      with_symbol_callback(subst_var_callback(var, val)) %>%
      with_call_callback(call_callback)
  ) %>% remove_formal_(var)
}
With this implementation we will actually evaluate the function calls in rescale when we substitute a value for x:
rescale_pe <- rescale %>% pe(x, test_x)
rescale_pe
#> function (y)
#> {
#> (y - 0.46)/0.330605505096331
#> }
rescale(test_x, test_y)
#> [1] 207.4550 376.3745 -762.1280 -334.4448 164.2914
rescale_pe(test_y)
#> [1] 207.4550 376.3745 -762.1280 -334.4448 164.2914
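As an aside, the folding rule used in the call callback (evaluate a call once all its arguments are atomic) can be sketched in base R without foolbox. The function fold_constants is a hypothetical illustration: it ignores scoping, does not handle calls with empty arguments such as x[1, ], and will happily run functions with side effects:

```r
# Hypothetical bottom-up constant folder (not part of foolbox).
fold_constants <- function(expr, env = globalenv()) {
  if (!is.call(expr)) return(expr)
  # fold the arguments first, then rebuild the call around them
  args <- lapply(as.list(expr)[-1], fold_constants, env = env)
  folded <- as.call(c(list(expr[[1]]), args))
  if (length(args) > 0 && all(vapply(args, is.atomic, logical(1)))) {
    # every argument is a value, so try to evaluate; keep the call on error
    tryCatch(eval(folded, env), error = function(e) folded)
  } else {
    folded
  }
}

folded_expr <- fold_constants(quote((y - mean(c(1, 2, 3))) / sd(c(1, 2, 3))))
folded_expr
eval(folded_expr, list(y = 4))
```

Only the calls that depend on y survive; mean and sd have been replaced by their values.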
The call callback also improves the performance: in the benchmark below, the transformed version runs at a median of roughly 11 ms against roughly 16 ms for the plain rescale function. The gain is modest, though, so for this example it is hardly worth the effort. Still, the idea is the important part, and it does show that you can do this with foolbox…
microbenchmark::microbenchmark(
  time_uncurry(rescale),
  time_curry(function(x) pe(rescale, x, x))
)
#> Unit: milliseconds
#>                                       expr       min       lq     mean   median       uq      max neval
#>                      time_uncurry(rescale) 14.034317 14.97801 16.83756 15.91812 16.40196 45.16957   100
#>  time_curry(function(x) pe(rescale, x, x))  9.028978 10.33661 12.24643 10.88464 11.52141 39.72597   100
Since we use env in the evaluation, we also capture variables from the environment where the function is defined.
What about other local functions?
rescale_ms <- function(x, y, mean, sd)
  (y - mean(x))/sd(x)
rescale_ms %>% pe(x, test_x)
#> function (y, mean, sd)
#> (y - 0.46)/0.330605505096331
We evaluate mean(x) and sd(x) in the environment of rescale_ms but not its execution environment. So we evaluate the expressions using the global functions. This will only give the right behaviour if it is these exact two functions that are passed as the arguments.
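A small base-R illustration of why this matters. The function scale_with is a hypothetical stand-in with the same shape as rescale_ms, and the choice of median and mad as arguments is only an example:

```r
# Same shape as rescale_ms: the centre and scale functions are parameters.
scale_with <- function(x, y, mean, sd) (y - mean(x)) / sd(x)

x <- c(1, 2, 10)
y <- 4

# What gets baked in if mean(x) and sd(x) are evaluated with the global functions
folded <- (y - base::mean(x)) / stats::sd(x)

# What a caller gets when passing different functions for the two parameters
actual <- scale_with(x, y, mean = stats::median, sd = stats::mad)

c(folded = folded, actual = actual)
```

The two results disagree, so folding the calls ahead of time would silently change what the function computes.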
In general, we should not evaluate functions that are local variables, at least not in the function’s environment. We can fix this by checking if a function name is a local variable before we evaluate a call:
is_bound <- function(var, expr) {
  (rlang::is_symbol(var) || is.character(var)) &&
    as.character(var) %in% attr(expr, "bound")
}
call_callback <- function(expr, env, ...) {
  call_args_atomic <- rlang::call_args(expr) %>%
    Map(rlang::is_atomic, .) %>% unlist
  if (any(!call_args_atomic) || is_bound(expr[[1]], expr)) expr
  else eval_attempt(expr, env)
}
pe <- function(fn, var, val) {
  var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
  force(val)
  fn %>% rewrite() %>% rewrite_with(
    rewrite_callbacks() %>%
      with_symbol_callback(subst_var_callback(var, val)) %>%
      with_call_callback(call_callback)
  ) %>% remove_formal_(var)
}
rescale %>% pe(x, test_x)
#> function (y)
#> {
#> (y - 0.46)/0.330605505096331
#> }
rescale_ms %>% pe(x, test_x)
#> function (y, mean, sd)
#> (y - mean(c(0.82, 0.17, 0.39)))/sd(c(0.82, 0.17, 0.39))
Since functions are also data, we can do partial evaluation with them:
rescale_ms %>% pe(x, test_x) %>% pe(mean, mean)
#> function (y, sd)
#> (y - 0.46)/sd(c(0.82, 0.17, 0.39))
rescale_ms %>% pe(x, test_x) %>% pe(mean, mean) %>% pe(sd, sd)
#> function (y)
#> (y - 0.46)/0.330605505096331
What about local variables?
rescale_nested <- function(x, y) {
  move <- function(x, y) y - mean(x)
  move(x, y) / sd(x)
}
rescale_nested %>% pe(x, test_x)
#> function (y)
#> {
#> move <- function(x, y) y - 0.46
#> move(c(0.82, 0.17, 0.39), y)/0.330605505096331
#> }
Here we are substituting x inside the body of move. It turns out to be okay because we later call move with x, but that is just dumb luck. In general, we don’t want to substitute variables that are local to another local function. We can use a top-down callback to skip a function body if it has the variable we are substituting as a parameter:
skip_overscoped_functions <- function(var_symbol) {
  var <- as.character(var_symbol)
  function(expr, skip, ...) {
    # in `function` expressions, the pair-list is element 2
    if (var %in% names(expr[[2]])) {
      skip(expr)
    }
  }
}
pe <- function(fn, var, val) {
  var <- rlang::enexpr(var) ; stopifnot(rlang::is_symbol(var))
  force(val)
  fn %>% rewrite() %>% rewrite_with(
    rewrite_callbacks() %>%
      with_symbol_callback(subst_var_callback(var, val)) %>%
      with_call_callback(call_callback) %>%
      add_topdown_callback(`function`, skip_overscoped_functions(var))
  ) %>% remove_formal_(var)
}
rescale_nested %>% pe(x, test_x)
#> function (y)
#> {
#> move <- function(x, y) y - mean(x)
#> move(c(0.82, 0.17, 0.39), y)/0.330605505096331
#> }
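The test that skip_overscoped_functions relies on can be tried in isolation in base R. The helper shadows_variable is hypothetical and only illustrates how the formal parameters of a quoted function expression are inspected:

```r
# Hypothetical helper: TRUE if `expr` is a `function` expression that has
# `var_name` among its formal parameters.
shadows_variable <- function(expr, var_name) {
  is.call(expr) &&
    identical(expr[[1]], as.symbol("function")) &&
    var_name %in% names(expr[[2]])   # element 2 is the pair-list of formals
}

inner <- quote(function(x, y) y - mean(x))
shadows_variable(inner, "x")   # x is a parameter of the inner function
shadows_variable(inner, "z")   # z is not
```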
Ok, so now we avoid substituting variables inside a local function, but in this particular case, we would like to substitute in the call to move so we can get mean(x) evaluated there.
To handle local functions there is a bit of work to be done. First, we need a way to evaluate calls to them, and that means partial evaluation. We want to substitute the parameters we know the values of and leave the rest. We can do this by iterating through the arguments in a call and substituting variables like this:
pe_call <- function(fun, expr) {
  # match the call's arguments to the parameter names of `fun`
  params <- rlang::call_args(match.call(fun, expr))
  substitutable <- Map(rlang::is_atomic, params) %>% unlist
  not_substitutable <- params[!substitutable]
  params <- params[substitutable]
  param_names <- names(params)
  # partially evaluate `fun` once for each argument with a known value
  for (i in seq_along(params)) {
    var <- as.symbol(param_names[i])
    fun <- fun %>% pe_(var, params[[i]])
  }
  # call the specialised function with the arguments that remain
  rlang::expr((!!fun)(!!!not_substitutable))
}
Here, I assume that I have a pe_ function that works like pe but where var is already quoted. I define it below (and redefine pe in terms of it).
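The argument-matching step in pe_call can be looked at on its own with base R. In this sketch, the function move_by and the call built by hand are illustrative assumptions, and vapply with is.atomic stands in for the rlang helpers:

```r
move_by <- function(x, y) y - mean(x)

# Build the call move_by(<numeric vector>, y): the first argument is already
# a value, the second is still a symbol.
call_expr <- as.call(list(as.symbol("move_by"), c(0.82, 0.17, 0.39), as.symbol("y")))

# match.call() attaches parameter names to the positional arguments
matched <- match.call(move_by, call_expr)
args <- as.list(matched)[-1]

known <- vapply(args, is.atomic, logical(1))
names(args)[known]    # parameters we can substitute
names(args)[!known]   # parameters that must stay
```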
With pe_call we can handle a call. Next, we need to collect all local function expressions and evaluate them into actual functions. We do this in a callback on assignments (<- and =). I’m not entirely sure how to deal with variables in a local function’s closure. In theory, we should be able to do some substitution inside the function and leave those variables alone, but I plan to evaluate the expressions that define the functions, and I need an environment to do this in. This cannot be the execution environment, which doesn’t exist before we call the function, so I can only use the function’s environment, and the local variables do not exist there. So I won’t allow a local function to refer to local variables from outside its own scope (its parameters and the local variables inside it), with the exception of the variable we are substituting. This function extracts all symbols used in an expression:
collect_all_symbols <- function(expr) {
  expr %>% analyse_expr() %>% analyse_expr_with(
    analysis_callbacks() %>% with_symbol_callback(
      function(expr, ...) list(symbols = as.character(expr))
    )) %>% unlist(use.names = FALSE)
}
and I can use it in this function for getting hold of local functions:
save_local_functions <- function(expr, env, params, var, fun_table, ...) {
  # only handle assignments to a plain name (the target is element 2)
  if (!rlang::is_symbol(expr[[2]]))
    return(expr)
  # ... whose right-hand side is a `function` expression
  if (!rlang::is_call(expr[[3]]) || length(expr[[3]]) < 2 ||
      expr[[3]][[1]] != "function")
    return(expr)
  fun_name <- as.character(expr[[2]])
  fun_formals <- expr[[3]][[2]]
  fun_body <- expr[[3]][[3]]
  fun_parameters <- names(fun_formals)
  fun_used_vars <- collect_all_symbols(fun_body)
  local_vars <- attr(expr, "bound")
  # figure out if the function uses local variables that are not
  # passed through the formal parameters
  closure_scope <- setdiff(fun_used_vars, fun_parameters)
  local_closure_scope <- intersect(local_vars, closure_scope)
  # more than one such variable: at least one is not the substituted one
  if (length(local_closure_scope) > 1)
    return(expr)
  if (length(local_closure_scope) == 1 &&
      local_closure_scope != as.character(var))
    return(expr)
  # create the function and save it in the function table
  stopifnot(!exists(fun_name, fun_table))
  fun_table[[fun_name]] <- eval(expr[[3]], env)
  expr
}
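The scope check in the middle of save_local_functions is plain set arithmetic and can be tried without foolbox. In this sketch, base R’s all.vars() stands in for collect_all_symbols (an assumption; the two may differ in which symbols they report), and the variable names are made up for the example:

```r
# A local function body that uses a parameter-independent variable `offset`
fun_body <- quote(y - mean(x) + offset)
fun_parameters <- c("x", "y")
local_vars <- c("offset", "move")      # variables bound in the outer function

fun_used_vars <- all.vars(fun_body)    # stand-in for collect_all_symbols()

# variables used in the body that are not parameters of the local function
closure_scope <- setdiff(fun_used_vars, fun_parameters)
# ... and that are local variables of the enclosing function
local_closure_scope <- intersect(local_vars, closure_scope)
local_closure_scope
```

Here the local function reaches out to offset, so under the rule described above it would be left alone unless offset happened to be the variable being substituted.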
I store the local functions in a table, fun_table. I will use an environment here, so the callback has side effects. It is the easiest way to implement this table.
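Environments have reference semantics, which is what makes them usable as a table that a callback can update in place. A minimal base-R sketch, where the register helper and the table contents are hypothetical:

```r
# An empty environment used as a mutable lookup table
fun_table <- new.env(parent = emptyenv())

# Hypothetical helper: store `fn` under `name`, refusing to overwrite
register <- function(name, fn, table) {
  stopifnot(!exists(name, envir = table, inherits = FALSE))
  table[[name]] <- fn   # modifies the caller's table, no copy is made
  invisible(NULL)
}

register("move", function(x, y) y - mean(x), fun_table)

exists("move", envir = fun_table, inherits = FALSE)
fun_table[["move"]](c(1, 2, 3), 5)
```

Had fun_table been a list, the assignment inside register would only have changed a local copy.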
In the call callback I then do partial evaluation of calls to local functions:
call_callback <- function(expr, env, fun_table, ...) {
  call_args_atomic <- rlang::call_args(expr) %>%
    Map(rlang::is_atomic, .) %>% unlist
  fun_name <- as.character(expr[[1]])
  if (all(call_args_atomic) && !is_bound(fun_name, expr))
    return(eval_attempt(expr, env))
  # if it exists in the function table, we can try a partial evaluation
  # of it
  if (is_bound(fun_name, expr) && exists(fun_name, fun_table)) {
    fun <- fun_table[[fun_name]]
    # if we have all the arguments, we can just call the function
    # and insert the result
    if (all(call_args_atomic))
      return(do.call(fun, rlang::call_args(expr)))
    # otherwise, we have to try a partial substitution of the function
    return(pe_call(fun, expr))
  }
  expr
}
We can now put all this together to get a new pe (and pe_) function:
pe_ <- function(fn, var, val) {
  stopifnot(rlang::is_symbol(var))
  force(val)
  fn %>% rewrite() %>% rewrite_with(
    rewrite_callbacks() %>%
      with_symbol_callback(subst_var_callback(var, val)) %>%
      with_call_callback(call_callback) %>%
      add_topdown_callback(`function`, skip_overscoped_functions(var)) %>%
      add_topdown_callback(`<-`, save_local_functions) %>%
      add_topdown_callback(`=`, save_local_functions),
    var = var, fun_table = rlang::new_environment()
  ) %>% remove_formal_(var)
}
pe <- function(fn, var, val) {
  var <- rlang::enexpr(var)
  pe_(fn, var, val)
}
rescale_nested <- function(x, y) {
  move <- function(x, y) y - mean(x)
  move(x, y) / sd(x)
}
rescale_nested %>% pe(x, test_x)
#> function (y)
#> {
#> move <- function(x, y) y - mean(x)
#> (function (y)
#> y - 0.46)(y = y)/0.330605505096331
#> }
rescale_nested_pe <- rescale_nested %>% pe(x, test_x)
rescale_nested_pe
#> function (y)
#> {
#> move <- function(x, y) y - mean(x)
#> (function (y)
#> y - 0.46)(y = y)/0.330605505096331
#> }
rescale(test_x, test_y)
#> [1] 207.4550 376.3745 -762.1280 -334.4448 164.2914
rescale_nested(test_x, test_y)
#> [1] 207.4550 376.3745 -762.1280 -334.4448 164.2914
rescale_nested_pe(test_y)
#> [1] 207.4550 376.3745 -762.1280 -334.4448 164.2914
As it turns out, this solution puts an expression in the scope where local variables will be present, so I didn’t have to do all that checking of the scope. So I will remove it again.
save_local_functions <- function(expr, env, params, var, fun_table, ...) {
  # only handle assignments to a plain name (the target is element 2)
  if (!rlang::is_symbol(expr[[2]]))
    return(expr)
  # ... whose right-hand side is a `function` expression
  if (!rlang::is_call(expr[[3]]) || length(expr[[3]]) < 2 ||
      expr[[3]][[1]] != "function")
    return(expr)
  fun_name <- as.character(expr[[2]])
  stopifnot(!exists(fun_name, fun_table))
  fun_table[[fun_name]] <- eval(expr[[3]], env)
  expr
}
rescale_nested <- function(x, y, m) {
  mean <- m
  move <- function(x, y) y - mean(x)
  move(x, y) / sd(x)
}
rescale_nested %>% pe(x, test_x)
#> function (y, m)
#> {
#> mean <- m
#> move <- function(x, y) y - mean(x)
#> (function (y)
#> y - 0.46)(y = y)/0.330605505096331
#> }
This example, however, illustrates that there is still some work that could be done. We could propagate the local value of mean into move. Trying for that, though, leads us into territory I don’t want to enter. It is undecidable to figure out what a local variable is set to at any particular point in a program unless we put some serious restrictions on it or do some serious static analysis to handle special cases where we are able to determine it. For this document, it takes us too far into static analysis, so I will stop here.
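A tiny example of why the value of a local variable cannot, in general, be known before the function runs. The function rescale_branch and its robust flag are hypothetical:

```r
# Which function `centre` refers to depends on a value only known at call time
rescale_branch <- function(x, y, robust) {
  centre <- if (robust) median else mean
  (y - centre(x)) / sd(x)
}

x <- c(1, 2, 10)
with_mean   <- rescale_branch(x, 4, robust = FALSE)
with_median <- rescale_branch(x, 4, robust = TRUE)
c(with_mean = with_mean, with_median = with_median)
```

Even with x fixed, centre(x) cannot be folded into a constant unless robust is known too, and real code can hide such dependencies behind loops and arbitrary computation.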
In any case, we are now doing so much work with the partial evaluation that we are only making the program slower, and as it is, we only barely got better than the simple rescale function earlier.
I do not know if there is a use for partial evaluation in R. The language is not built for speed in the first place, and optimisations would be better handled in an adaptive virtual machine. If you can find a use for partial evaluation, though, I think you can implement it using foolbox.