Trampolines
Some programs are easily expressed in terms of recursion. Recursion has some positive attributes. It's a declarative and concise way of describing some calculations. It works well with immutable data, which are useful for avoiding certain kinds of bugs, especially in concurrent code. It's also a lot of fun!
In languages without built-in support for optimizing recursive functions, one runs into an obstacle. Recursive function calls can build large call stacks, and eventually hit a limit or overflow the stack.
Consider the algorithm for computing the product of a collection of numbers:
Π(X...) = X₀ if count(X) = 1
Π(X...) = X₀ * Π(X₁...) if count(X) > 1
In Python (for example), we might implement this as
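def product(numbers):
    if len(numbers) == 1:
        # Base case: the product of a single number is that number.
        return numbers[0]
    elif len(numbers) > 1:
        # Recursive case: first element times the product of the rest.
        return numbers[0] * product(numbers[1:])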
Each recursive function call needs to complete before the calling function can
return, and there's a limit to how many incomplete function evaluations we can
maintain. (In Python, that limit is 1000 by default. Calling
product(list(range(1, 1000))) will raise a RecursionError.)
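The limit can be inspected at runtime with sys.getrecursionlimit(). The sketch below sizes its input relative to that limit instead of hard-coding 1000, since the exact list length that triggers the error depends on how many frames are already on the stack. The variable names are only for illustration.

```python
import sys


def product(numbers):
    # Same recursive definition as above.
    if len(numbers) == 1:
        return numbers[0]
    elif len(numbers) > 1:
        return numbers[0] * product(numbers[1:])


limit = sys.getrecursionlimit()  # 1000 by default in CPython

# A short list is fine...
small_result = product([1, 2, 3, 4])

# ...but a list longer than the recursion limit cannot be processed.
too_long = list(range(1, limit + 100))
try:
    product(too_long)
    overflowed = False
except RecursionError:
    overflowed = True
```

Raising the limit with sys.setrecursionlimit only moves the problem, which is why the rest of this post looks for a different approach.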
If we need to compute products of lists with many elements, we could flatten our algorithm into a procedural one. To see how this works, let's first translate our function to use an accumulator argument, so that the last thing the function does is call itself (rather than multiply the result of that call):
def product_tailcall(acc, numbers):
    if len(numbers) == 1:
        return acc * numbers[0]
    elif len(numbers) > 1:
        # The recursive call is now the final operation (a tail call).
        return product_tailcall(acc * numbers[0], numbers[1:])
If we start with acc = 1 (since 1 is the multiplicative identity), then this
gives the same result as product, but no longer strictly needs to retain the
intermediate function evaluations. (Unfortunately, CPython still retains them,
because it does not eliminate tail calls.) It's clear how this can be flattened
into a loop:
def product_loop(numbers):
    acc = 1
    remainder = numbers
    while len(remainder) > 0:
        acc *= remainder[0]
        remainder = remainder[1:]
    return acc
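As a quick check of both claims, the sketch below runs the accumulator version and the loop on an input longer than the recursion limit. It assumes CPython's behaviour, where tail calls still consume stack frames, and it uses math.prod (available from Python 3.8) only as an independent reference in the final comparison.

```python
import math
import sys


def product_tailcall(acc, numbers):
    if len(numbers) == 1:
        return acc * numbers[0]
    elif len(numbers) > 1:
        return product_tailcall(acc * numbers[0], numbers[1:])


def product_loop(numbers):
    acc = 1
    remainder = numbers
    while len(remainder) > 0:
        acc *= remainder[0]
        remainder = remainder[1:]
    return acc


numbers = list(range(1, sys.getrecursionlimit() + 100))

# The tail-recursive form still overflows the interpreter's stack.
try:
    product_tailcall(1, numbers)
    tailcall_overflowed = False
except RecursionError:
    tailcall_overflowed = True

# The loop handles the same input without trouble.
loop_result = product_loop(numbers)
matches_reference = loop_result == math.prod(numbers)
```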
That works, but I find it a bit unsatisfactory. We've lost some of the elegance of the first algorithm, and it's harder to see that it corresponds to the original mathematical notation. Can we do something to automatically transform the first algorithm into the second?
We can, using a trick called trampolining! The idea is to postpone the
evaluation of the recursive function, so that we can work with it as data.
First, let's define a class to represent deferred computation, which I'll call
Suspend because it suspends the evaluation of a value:
class Suspend:
    pass
It will have two subclasses for now, Value and Bind:
import dataclasses
from typing import Any, Callable

@dataclasses.dataclass
class Value(Suspend):
    value: Any

@dataclasses.dataclass
class Bind(Suspend):
    func: Callable[[Any], Suspend]
    over: Suspend
Value represents a pure value, like a number or a string. Bind represents
the postponed application of a function to a value, and it allows us to bind a
function call to a value without evaluating it and adding it to the call stack.
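To make "without evaluating it" concrete, here is a small sketch. The double function and the calls list are hypothetical helpers added for illustration; they record when the function actually runs.

```python
import dataclasses
from typing import Any, Callable


class Suspend:
    pass


@dataclasses.dataclass
class Value(Suspend):
    value: Any


@dataclasses.dataclass
class Bind(Suspend):
    func: Callable[[Any], Suspend]
    over: Suspend


calls = []


def double(x):
    # Record the call so we can see when it actually happens.
    calls.append(x)
    return Value(x * 2)


# Building the Bind does not call double; it only records the intent.
deferred = Bind(double, Value(21))
calls_after_build = list(calls)

# Applying the function by hand performs the single deferred step.
result = deferred.func(deferred.over.value)
```

Until something applies func, the Bind is just an ordinary object holding a function and an argument, so it can be stored, inspected, or passed around like any other data.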
In order to use this, we need to write a way to evaluate this type and obtain the result it describes. We can define some generic unrolling code:
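def evaluate(program):
    current = program
    while isinstance(current, Suspend):
        if isinstance(current, Value):
            # A plain value: nothing left to do.
            return current.value
        elif isinstance(current, Bind):
            # Evaluate the argument, then take one step by applying func.
            arg = evaluate(current.over)
            current = current.func(arg)
            continue
    # Anything that is not a Suspend cannot be evaluated.
    raise TypeError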
The input to evaluate is a suspended calculation. The suspended calculation
can be recursive, because it's not evaluated directly by the Python interpreter.
Instead, the recursion is flattened and handled by the while loop. The loop acts
like our call stack, but there's no limit to the number of repetitions it can
perform. (At this stage, we've implemented our own very simple
interpreter.)
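Because evaluate knows nothing about products, it should work for any recursion written in this style. As a sketch, here is a hypothetical sum_to function (not part of the original example) that adds the integers from n down to 1 and recurses far deeper than Python's stack would allow.

```python
import dataclasses
from typing import Any, Callable


class Suspend:
    pass


@dataclasses.dataclass
class Value(Suspend):
    value: Any


@dataclasses.dataclass
class Bind(Suspend):
    func: Callable[[Any], Suspend]
    over: Suspend


def evaluate(program):
    current = program
    while isinstance(current, Suspend):
        if isinstance(current, Value):
            return current.value
        elif isinstance(current, Bind):
            arg = evaluate(current.over)
            current = current.func(arg)
            continue
    raise TypeError


def sum_to(n, acc=0):
    # Hypothetical example: acc carries the running total.
    if n == 0:
        return Value(acc)
    # Describe the next step instead of calling sum_to directly.
    return Bind(lambda m: sum_to(m, acc + n), Value(n - 1))


# 100,000 recursive steps, far beyond the default recursion limit.
total = evaluate(sum_to(100_000))
```

Each pass through the while loop performs one step of the recursion and then discards it, so the Python call stack never grows by more than a couple of frames.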
We can write the original recursive algorithm using our new classes. (I've
introduced functools.partial, but one could also use default arguments or
curry the function instead.)
from functools import partial

def product_suspended(acc, numbers):
    if len(numbers) == 1:
        return Value(numbers[0] * acc)
    else:
        num = numbers[0]
        rest = numbers[1:]
        return Bind(partial(product_suspended, acc * num), Value(rest))

numbers = list(range(1, 2000))
program = Bind(partial(product_suspended, 1), Value(numbers))
# Prove that it works!
assert evaluate(program) == product_loop(numbers)
Now, arguably the extra machinery that we've added has made this (non-Pythonic)
code a bit harder to understand. I think this is still a neat exercise though,
because it's a basic demonstration of how we can create frameworks for
describing programs that wouldn't be possible or practical to execute directly.
The technique can be employed when it's desirable to describe a computation but
be clever about how it's actually performed. For example, we might want to
implement optimizations in the evaluate procedure, or take the opportunity to
replace calculations described in our high-level language with faster
implementations in a different language.
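As one illustration of changing how a program is performed without changing how it is described, here is a sketch of an alternative evaluator. The evaluate above calls itself on current.over, so Binds nested on that side still consume Python stack frames. The hypothetical evaluate_with_stack below keeps pending functions in an ordinary list instead. This is one possible design under those assumptions, not the only one, and the names are made up for this example.

```python
import dataclasses
from typing import Any, Callable


class Suspend:
    pass


@dataclasses.dataclass
class Value(Suspend):
    value: Any


@dataclasses.dataclass
class Bind(Suspend):
    func: Callable[[Any], Suspend]
    over: Suspend


def evaluate_with_stack(program):
    # Hypothetical variant of evaluate: functions waiting for their
    # argument live in a list rather than on Python's call stack.
    pending = []
    current = program
    while True:
        if isinstance(current, Bind):
            # Remember the function and evaluate its argument first.
            pending.append(current.func)
            current = current.over
        elif isinstance(current, Value):
            if not pending:
                return current.value
            # Feed the value to the most recently deferred function.
            func = pending.pop()
            current = func(current.value)
        else:
            raise TypeError(f"expected a Suspend, got {type(current).__name__}")


def increment(x):
    return Value(x + 1)


# Nest Binds on the `over` side, 10,000 levels deep.
program = Value(0)
for _ in range(10_000):
    program = Bind(increment, program)

result = evaluate_with_stack(program)
```

The program description (Value and Bind) is unchanged; only the interpreter differs, which is the kind of freedom the paragraph above describes.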