# Trampolines

Some programs are easily expressed in terms of recursion. Recursion has some positive attributes. It's a declarative and concise way of describing some calculations. It works well with immutable data, which are useful for avoiding certain kinds of bugs, especially in concurrent code. It's also a lot of fun!

In languages without built-in support for optimizing recursive functions (tail-call elimination, for example), one runs into an obstacle: recursive function calls can build large call stacks, and eventually hit a limit or overflow.

Consider the algorithm for computing the product of a collection of numbers:

```
Π(X...) = X₀            if count(X) = 1
Π(X...) = X₀ * Π(X₁...) if count(X) > 1
```
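Unrolling the definition on a small input such as X = (2, 3, 4) shows why the recursion has to remember pending work. Each multiplication waits for the product of the remaining elements:

$$
\begin{aligned}
\Pi(2, 3, 4) &= 2 \cdot \Pi(3, 4) \\
&= 2 \cdot \bigl(3 \cdot \Pi(4)\bigr) \\
&= 2 \cdot (3 \cdot 4) \\
&= 24
\end{aligned}
$$

Neither outer multiplication can happen until the innermost $\Pi(4)$ returns. In a direct implementation, each of them occupies a stack frame while it waits.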

In Python (for example), we might implement this as

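```python
def product(numbers):
    if len(numbers) == 1:
        # Base case: the product of a single number is that number.
        return numbers[0]
    elif len(numbers) > 1:
        # Recursive case: the first number times the product of the rest.
        return numbers[0] * product(numbers[1:])
```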

Each recursive call needs to complete before the calling function can return, and there's a limit to how many incomplete function evaluations we can maintain. In CPython, that limit is 1000 by default (see `sys.getrecursionlimit()`), so calling `product(list(range(1, 2000)))` will raise a `RecursionError`.
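The small sketch below reproduces the failure. `sys.getrecursionlimit()` reports the current limit. The input size of 10,000 elements is an arbitrary choice that sits comfortably past the default.

```python
import sys

def product(numbers):
    # Naive recursion: one Python frame per element.
    if len(numbers) == 1:
        return numbers[0]
    return numbers[0] * product(numbers[1:])

# 1000 unless it has been changed with sys.setrecursionlimit().
print(sys.getrecursionlimit())

try:
    product(list(range(1, 10_000)))
except RecursionError as exc:
    print("RecursionError:", exc)
```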

If we need to compute products of lists with many elements, we could flatten our algorithm into an iterative one. To see how this works, let's first rewrite the function to use an accumulator argument. The recursive call then becomes the very last thing the function does (a tail call), rather than an operand of a multiplication:

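```python
def product_tailcall(acc, numbers):
    if len(numbers) == 1:
        # Base case: fold the last number into the accumulator.
        return acc * numbers[0]
    elif len(numbers) > 1:
        # Multiply first, then make the recursive call in tail position.
        return product_tailcall(acc * numbers[0], numbers[1:])
```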

If we start with `acc = 1` (since 1 is the multiplicative identity), this gives the same result as `product`. It also no longer strictly needs to retain the intermediate function evaluations, because once the recursive call returns there's nothing left for the caller to do. (Unfortunately, CPython retains them anyway, because it doesn't perform tail-call elimination.) It's now clear how this can be flattened into a loop:
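The quick sketch below confirms both parts of that claim. Starting from `acc = 1` gives the expected product, but a long input still exhausts CPython's recursion limit. The 10,000-element input is an arbitrary size chosen to exceed the default.

```python
def product_tailcall(acc, numbers):
    if len(numbers) == 1:
        return acc * numbers[0]
    # Tail call: nothing remains to be done after it returns.
    return product_tailcall(acc * numbers[0], numbers[1:])

print(product_tailcall(1, [2, 3, 4]))  # 24

try:
    product_tailcall(1, list(range(1, 10_000)))
except RecursionError:
    # CPython keeps a frame for every call, tail call or not.
    print("RecursionError")
```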

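```python
def product_loop(numbers):
    acc = 1
    remainder = numbers
    # Each iteration plays the role of one tail call: update the
    # accumulator, then move on to the rest of the list.
    while len(remainder) > 0:
        acc *= remainder[0]
        remainder = remainder[1:]
    return acc
```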
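As a sanity check, this minimal sketch runs `product_loop` on an input well past the recursion limit. It compares the result with `math.prod` from the standard library (available since Python 3.8), which serves only as a reference.

```python
import math

def product_loop(numbers):
    acc = 1
    remainder = numbers
    while len(remainder) > 0:
        acc *= remainder[0]
        remainder = remainder[1:]
    return acc

numbers = list(range(1, 5_000))
# The loop runs at constant stack depth, so the input size doesn't matter.
assert product_loop(numbers) == math.prod(numbers)
print(product_loop([2, 3, 4]))  # 24
```

Slicing copies the remaining list on every iteration, so this version takes quadratic time. That's acceptable for illustration, and it keeps the loop a direct translation of the tail-recursive function.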

That works, but I find it a bit unsatisfactory. We've lost some of the elegance of the first algorithm, and it's harder to see that it corresponds to the original mathematical notation. Can we do something to automatically transform the first algorithm into the second?

We can, using a trick called trampolining! The idea is to postpone the evaluation of the recursive function, so that we can work with it as data. First, let's define a class to represent deferred computation, which I'll call `Suspend` because it suspends the evaluation of a value:

```python
class Suspend:
    pass
```

It will have two subclasses for now, `Value` and `Bind`:

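```python
import dataclasses
from typing import Any, Callable

@dataclasses.dataclass
class Value(Suspend):
    # An already-computed value; nothing left to do.
    value: Any

@dataclasses.dataclass
class Bind(Suspend):
    # Apply `func` to the result of `over`; `func` returns the next Suspend.
    func: Callable[[Any], Suspend]
    over: Suspend
```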

`Value` represents a pure value, like a number or a string. `Bind` represents the postponed application of a function to the result of another suspended computation (`over`). Because `func` itself returns a `Suspend`, a `Bind` lets us describe a function call without evaluating it and adding it to the call stack.
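To show that a `Bind` really is just data, the small sketch below wraps a hypothetical `add_one` helper in one and checks that nothing has been called yet. The classes are repeated so the snippet runs on its own.

```python
import dataclasses
from typing import Any, Callable

class Suspend:
    pass

@dataclasses.dataclass
class Value(Suspend):
    value: Any

@dataclasses.dataclass
class Bind(Suspend):
    func: Callable[[Any], Suspend]
    over: Suspend

calls = []

def add_one(x):
    calls.append(x)  # record that the function actually ran
    return Value(x + 1)

# Describe "add one to 41" without running anything.
program = Bind(add_one, Value(41))

print(calls)         # []
print(program.over)  # Value(value=41)
```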

In order to use this, we need a way to evaluate a `Suspend` and obtain the result it describes. We can define some generic unrolling code:

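```python
def evaluate(program):
    current = program
    # The loop stands in for the call stack: each iteration is one step.
    while isinstance(current, Suspend):
        if isinstance(current, Value):
            # A pure value: nothing left to compute.
            return current.value
        elif isinstance(current, Bind):
            # Evaluate the argument, then apply the function to get
            # the next suspended computation.
            arg = evaluate(current.over)
            current = current.func(arg)
            continue
        # An unrecognized Suspend subclass.
        raise TypeError(f"cannot evaluate {current!r}")
    # Not a Suspend at all (e.g. a function returned a plain value).
    raise TypeError(f"expected a Suspend, got {current!r}")
```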

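The input to `evaluate` is a suspended calculation. The suspended calculation can be recursive, because it isn't evaluated directly by the Python interpreter. Instead, the recursion is flattened and handled by the `while` loop. The loop acts like our call stack, but it isn't subject to Python's recursion limit, so there's no fixed cap on the number of steps it can perform. (The call `evaluate(current.over)` is still ordinary Python recursion. In the examples here `over` is always a `Value`, though, so it never nests more than one level deep.) At this stage, we've implemented our own very simple interpreter.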
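The minimal sketch below checks that claim with a hypothetical `count_down` function that "recurses" 100,000 times, far past CPython's default limit. The classes and `evaluate` are repeated so the snippet runs on its own.

```python
import dataclasses
from typing import Any, Callable

class Suspend:
    pass

@dataclasses.dataclass
class Value(Suspend):
    value: Any

@dataclasses.dataclass
class Bind(Suspend):
    func: Callable[[Any], Suspend]
    over: Suspend

def evaluate(program):
    current = program
    while isinstance(current, Suspend):
        if isinstance(current, Value):
            return current.value
        elif isinstance(current, Bind):
            arg = evaluate(current.over)
            current = current.func(arg)
            continue
        raise TypeError(f"cannot evaluate {current!r}")
    raise TypeError(f"expected a Suspend, got {current!r}")

def count_down(n):
    # Hypothetical example: "recurse" n times, then report "done".
    if n == 0:
        return Value("done")
    return Bind(count_down, Value(n - 1))

print(evaluate(count_down(100_000)))  # done
```

Each `count_down` call returns immediately with a new `Bind`. The Python stack therefore stays shallow while the `while` loop does all the repeating.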

We can write the original recursive algorithm using our new classes. (I've introduced `functools.partial`, but one could also use default arguments or curry the function instead.)

```python
from functools import partial

def product_suspended(acc, numbers):
    if len(numbers) == 1:
        # Base case: fold the last number into the accumulator.
        return Value(numbers[0] * acc)
    else:
        num = numbers[0]
        rest = numbers[1:]
        # Describe the recursive call instead of making it.
        return Bind(partial(product_suspended, acc * num), Value(rest))

numbers = list(range(1, 2000))
program = Bind(partial(product_suspended, 1), Value(numbers))

# Prove that it works! (Plain recursion would overflow on this input.)
assert evaluate(program) == product_loop(numbers)
```
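For comparison, trampolines are also commonly written in a lighter style that needs no new classes. A function returns either its final result or a zero-argument callable (a "thunk") for the next step, and a driver loop keeps calling until it gets something that isn't callable. This is a swapped-in variant, not the `Suspend` design above, and `trampoline` and `product_thunk` are illustrative names.

```python
import math

def trampoline(step):
    # Keep "bouncing" while we're handed another step to run.
    while callable(step):
        step = step()
    return step

def product_thunk(acc, numbers):
    if len(numbers) == 1:
        return acc * numbers[0]
    # Return the next step instead of calling it, so the stack never grows.
    return lambda: product_thunk(acc * numbers[0], numbers[1:])

numbers = list(range(1, 2000))
result = trampoline(lambda: product_thunk(1, numbers))
assert result == math.prod(numbers)
```

The catch is that a computation whose final result is itself callable can't be expressed this way. An explicit type like `Suspend` makes the distinction between "more work to do" and "finished" unambiguous.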

Now, arguably the extra machinery we've added has made this (non-Pythonic) code harder to understand. I still think it's a neat exercise, because it's a basic demonstration of how we can create a framework for describing programs that wouldn't be possible or practical to execute directly. The technique is useful whenever we want to describe a computation but be clever about how it's actually performed. For example, we might implement optimizations in the `evaluate` procedure, or replace calculations described in our high-level language with faster implementations in a different language.
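Here is one possible optimization of that kind, sketched under our own assumptions. The `evaluate` function still makes an ordinary Python call for `current.over`, so a deeply left-nested program (a `Bind` whose `over` is another `Bind`, and so on) would hit the recursion limit. Keeping the pending functions on an explicit list avoids that. `evaluate_stack` and `increment` are hypothetical names.

```python
import dataclasses
from typing import Any, Callable

class Suspend:
    pass

@dataclasses.dataclass
class Value(Suspend):
    value: Any

@dataclasses.dataclass
class Bind(Suspend):
    func: Callable[[Any], Suspend]
    over: Suspend

def evaluate_stack(program):
    # Functions waiting for the result of their `over` computation,
    # innermost last. This list replaces the nested evaluate() call.
    pending = []
    current = program
    while True:
        if isinstance(current, Bind):
            # Defer the function and evaluate its argument first.
            pending.append(current.func)
            current = current.over
        elif isinstance(current, Value):
            if not pending:
                return current.value
            # Feed the value to the most recently deferred function.
            current = pending.pop()(current.value)
        else:
            raise TypeError(f"cannot evaluate {current!r}")

def increment(x):
    return Value(x + 1)

# A left-nested chain: Bind(increment, Bind(increment, ... Value(0))).
program = Value(0)
for _ in range(10_000):
    program = Bind(increment, program)

print(evaluate_stack(program))  # 10000
```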