
Trampolines

Some programs are easily expressed in terms of recursion. Recursion has some positive attributes. It's a declarative and concise way of describing some calculations. It works well with immutable data, which are useful for avoiding certain kinds of bugs, especially in concurrent code. It's also a lot of fun!

In languages without built-in support for optimizing recursive functions (such as tail-call elimination), one runs into an obstacle: recursive function calls can build large call stacks, and eventually hit a limit or overflow.

Consider the algorithm for computing the product of a collection of numbers:

Π(X...) = X₀            if count(X) = 1
Π(X...) = X₀ * Π(X₁...) if count(X) > 1
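Unrolling the definition on a small example, say X = (2, 3, 4), shows how each step leaves a multiplication waiting for the result of the next one:

```latex
\begin{aligned}
\Pi(2, 3, 4) &= 2 \cdot \Pi(3, 4) \\
             &= 2 \cdot \bigl(3 \cdot \Pi(4)\bigr) \\
             &= 2 \cdot (3 \cdot 4) = 24
\end{aligned}
```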

In Python (for example), we might implement this as:

def product(numbers):
  if len(numbers) == 1:
    return numbers[0]
  elif len(numbers) > 1:
    return numbers[0] * product(numbers[1:])

Each recursive function call needs to complete before the calling function can return, and there's a limit to how many incomplete function evaluations we can maintain. (In Python, that limit is 1000 by default; see sys.getrecursionlimit().) Calling product(list(range(1, 2000))) will raise a RecursionError.
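A quick way to check the limit and see the failure for yourself is sketched below. It repeats the product function so that it runs on its own, and it sizes the input relative to sys.getrecursionlimit() so that it overflows even if the limit has been changed from the default:

```python
import sys

def product(numbers):
  # The recursive definition from above: one stack frame per element.
  if len(numbers) == 1:
    return numbers[0]
  elif len(numbers) > 1:
    return numbers[0] * product(numbers[1:])

limit = sys.getrecursionlimit()  # 1000 unless something has changed it
print("recursion limit:", limit)

try:
  # Roughly twice as many elements as the limit allows frames.
  product(list(range(1, 2 * limit)))
  overflowed = False
except RecursionError:
  overflowed = True
print("overflowed:", overflowed)
```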

If we need to compute products of lists with many elements, we could flatten our algorithm into a procedural one. To see how this works, let's first translate our function to use an accumulator argument, so that the recursive call is a tail call: its result is returned directly, instead of being multiplied by something first.

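def product_tailcall(acc, numbers):
  # acc holds the product of the elements consumed so far.
  if len(numbers) == 1:
    return acc * numbers[0]
  elif len(numbers) > 1:
    # Tail call: its result is returned as-is, with no work left to do here.
    return product_tailcall(acc * numbers[0], numbers[1:])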
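Recording the arguments of each call makes the accumulator's role visible. The sketch below uses a hypothetical traced variant, product_tailcall_traced (my name, not part of the original), that appends each (acc, numbers) pair to a list. Each pair depends only on the previous one, which is what lets the calls be replaced by loop iterations:

```python
def product_tailcall_traced(acc, numbers, trace):
  # Record this call's arguments before doing any work.
  trace.append((acc, list(numbers)))
  if len(numbers) == 1:
    return acc * numbers[0]
  elif len(numbers) > 1:
    return product_tailcall_traced(acc * numbers[0], numbers[1:], trace)

trace = []
result = product_tailcall_traced(1, [2, 3, 4], trace)
for step in trace:
  print(step)
# (1, [2, 3, 4])
# (2, [3, 4])
# (6, [4])
print(result)  # 24
```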

If we start with acc = 1 (since 1 is the multiplicative identity), then product_tailcall gives the same result as product, but it no longer strictly needs to retain the intermediate function evaluations. (Unfortunately, CPython retains them anyway, because it doesn't perform tail-call elimination.) It's clear how this can be flattened into a loop:

def product_loop(numbers):
  acc = 1
  remainder = numbers
  while len(remainder) > 0:
    acc *= remainder[0]
    remainder = remainder[1:]
  return acc

That works, but I find it a bit unsatisfactory. We've lost some of the elegance of the first algorithm, and it's harder to see that it corresponds to the original mathematical notation. Can we do something to automatically transform the first algorithm into the second?

We can, using a trick called trampolining! The idea is to postpone the evaluation of the recursive function, so that we can work with it as data. First, let's define a class to represent deferred computation, which I'll call Suspend because it suspends the evaluation of a value:

class Suspend:
  pass

It will have two subclasses for now, Value and Bind:

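import dataclasses
from typing import Any, Callable

@dataclasses.dataclass
class Value(Suspend):
  value: Any  # an already-computed result

@dataclasses.dataclass
class Bind(Suspend):
  func: Callable[[Any], Suspend]  # receives the result of `over`, returns the next step
  over: Suspend                   # the computation whose result is fed to func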

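Value represents a pure value, like a number or a string. Bind represents the postponed application of a function to the result of another suspended computation (its over field). The function itself returns a new Suspend describing the next step, so Bind lets us chain function calls together without evaluating them and adding them to the call stack.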
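Because Value and Bind are plain dataclasses, building a suspended program does no arithmetic at all; it only produces a data structure that we can inspect. A small sketch, repeating the class definitions so it runs on its own (the increment function is purely illustrative):

```python
import dataclasses
from typing import Any, Callable

class Suspend:
  pass

@dataclasses.dataclass
class Value(Suspend):
  value: Any

@dataclasses.dataclass
class Bind(Suspend):
  func: Callable[[Any], Suspend]
  over: Suspend

def increment(x):
  # Returns another suspended computation rather than a bare number.
  return Value(x + 1)

# Describe "apply increment to 41" without calling increment yet.
program = Bind(increment, Value(41))
print(program.over)               # Value(value=41)
print(program.func is increment)  # True
```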

In order to use this, we need to write a way to evaluate this type and obtain the result it describes. We can define some generic unrolling code:

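def evaluate(program):
  current = program
  while isinstance(current, Suspend):
    if isinstance(current, Value):
      # A pure value: nothing left to do.
      return current.value
    elif isinstance(current, Bind):
      # Evaluate what func is bound over, then replace the current step with
      # whatever func returns. This recursive call is only one level deep
      # when `over` is a Value, as in the product example.
      arg = evaluate(current.over)
      current = current.func(arg)
    else:
      # An unknown Suspend subclass; stop instead of looping forever.
      break
  raise TypeError(f"cannot evaluate {current!r}")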

The input to evaluate is a suspended calculation. The suspended calculation can be recursive, because it's not evaluated directly by the Python interpreter. Instead, the recursion is flattened and handled by the while loop. The loop acts like our call stack, but the number of iterations it can perform isn't bounded by Python's recursion limit, as long as each Bind's over is a plain Value (the recursive call on over still uses the Python stack). At this stage, we've implemented our own very simple interpreter.
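If a program nests Binds inside the over field (for instance, one built by repeatedly wrapping the previous program), the recursive evaluate(current.over) call would grow the Python stack again. One possible variant, sketched below with names of my own choosing (evaluate_iterative, inc), keeps the pending functions on an explicit list instead, so the only stack is one we manage ourselves:

```python
import dataclasses
from typing import Any, Callable

class Suspend:
  pass

@dataclasses.dataclass
class Value(Suspend):
  value: Any

@dataclasses.dataclass
class Bind(Suspend):
  func: Callable[[Any], Suspend]
  over: Suspend

def evaluate_iterative(program):
  # Functions waiting for a result; the most recently deferred is last.
  pending = []
  current = program
  while True:
    if isinstance(current, Bind):
      # Defer func and evaluate what it is bound over first.
      pending.append(current.func)
      current = current.over
    elif isinstance(current, Value):
      if not pending:
        return current.value
      # Feed the value to the most recently deferred function.
      current = pending.pop()(current.value)
    else:
      raise TypeError(f"cannot evaluate {current!r}")

def inc(x):
  # Illustrative step function.
  return Value(x + 1)

# A chain nested through `over`: Bind(inc, Bind(inc, ... Value(0))).
program = Value(0)
for _ in range(10_000):
  program = Bind(inc, program)

print(evaluate_iterative(program))  # 10000
```

The list of pending functions plays the role of the call stack that evaluate's recursive call would otherwise use.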

We can write the original recursive algorithm using our new classes. (I've introduced functools.partial, but one could also use default arguments or curry the function instead.)

from functools import partial

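def product_suspended(acc, numbers):
  if len(numbers) == 1:
    # Base case: a finished value, nothing left to suspend.
    return Value(numbers[0] * acc)
  else:
    num = numbers[0]
    rest = numbers[1:]
    # Instead of recursing, describe the recursive call as data.
    return Bind(partial(product_suspended, acc * num), Value(rest))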

numbers = list(range(1, 2000))
program = Bind(partial(product_suspended, 1), Value(numbers))

# Prove that it works!
assert evaluate(program) == product_loop(numbers)
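For comparison, the name "trampoline" is also commonly used for an even simpler pattern: instead of making its recursive call, a function returns a zero-argument callable (a thunk), and a driver loop keeps calling thunks until it gets back a non-callable result. The sketch below (trampoline and product_thunk are my own names) swaps the Value/Bind data for plain closures, and it assumes the final result is never itself callable:

```python
from functools import partial

def trampoline(result):
  # Keep bouncing: call thunks until a plain (non-callable) value comes back.
  while callable(result):
    result = result()
  return result

def product_thunk(acc, numbers):
  if len(numbers) == 1:
    return acc * numbers[0]
  # Return the next step instead of calling it, so the stack never grows.
  return partial(product_thunk, acc * numbers[0], numbers[1:])

print(trampoline(product_thunk(1, [2, 3, 4])))  # 24

# Far past the default recursion limit, with no RecursionError.
big = trampoline(product_thunk(1, list(range(1, 2000))))
```

The thunks are opaque, though: unlike a Bind, a partial object doesn't describe its next step in a form that an evaluator can inspect or rewrite.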

Now, arguably the extra machinery that we've added has made this (non-Pythonic) code a bit harder to understand. I still think it's a neat exercise, though, because it's a basic demonstration of how we can create a framework for describing programs that wouldn't be possible or practical to execute directly. The technique is useful whenever we want to describe a computation but be clever about how it's actually performed. For example, we might want to implement optimizations in the evaluate procedure, or take the opportunity to replace calculations described in our high-level language with faster implementations in a different language.