Suspendable recursion with Scala continuations

September 16, 2011

Consider the following recursive method, which estimates Pi by summing the Leibniz series:

def pi(n: Int = 0): Double =
  4 * math.pow(-1, n) / (2 * n + 1) + pi(n + 1)
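Each call to pi(n) contributes the n-th term of the Leibniz series, so the recursion is attempting the infinite sum below. Writing S_N for the sum of the first N terms, the first few partial sums are:

```latex
\pi = \sum_{n=0}^{\infty} \frac{4\,(-1)^n}{2n+1}
    = 4 - \frac{4}{3} + \frac{4}{5} - \frac{4}{7} + \cdots

S_1 = 4, \quad S_2 = 4 - \tfrac{4}{3} \approx 2.6667, \quad
S_3 \approx 3.4667, \quad S_4 \approx 2.8952
```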

Having no base case, this method won't do much aside from heating up the CPU until it eventually blows the stack with a StackOverflowError. Sometimes it can be useful to peek into the progress of such a recursive method, which can be done by suspending it in a continuation.

def pi(n: Int = 0): Double @dd =
  4 * math.pow(-1, n) / (2 * n + 1) + suspend(pi(n + 1))

Here, the suspend function captures the rest of the computation as a continuation and sets aside its by-name argument, the recursive call to pi, to be run later. Meanwhile, the current estimate of Pi is returned to the caller. Resuming the continuation refines the estimate of Pi by one iteration and again sets aside the next recursive call to pi. The @dd annotation is simply an alias for the otherwise wordy @cpsParam[Double, Double].
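To make the mechanics of suspend concrete without the continuations plugin, here is a hand-written sketch in which the continuation is an ordinary function. It illustrates the idea and is not what the plugin actually generates; `ExplicitContinuation`, `Suspended`, `term`, `start` and `resume` are hypothetical names. The key step mirrors suspend: build k, read off the estimate as k(0.0), then keep only the collapsed sum for the rest of the recursion.

```scala
// Plugin-free sketch of what suspend does, with the continuation passed as an
// ordinary function. All names here are hypothetical, for illustration only.
object ExplicitContinuation {

  // Deferred work: `outer` adds the partial sum collected so far to whatever
  // the remaining terms are worth; `n` is the index of the next term.
  final case class Suspended(outer: Double => Double, n: Int)

  // n-th term of the Leibniz series, as in the pi method above
  def term(n: Int): Double = 4 * math.pow(-1, n) / (2 * n + 1)

  // Nothing has been summed yet, so the outer work is the identity
  val start: Suspended = Suspended(rest => rest, 0)

  // Runs one step of the recursion and suspends the rest of it.
  def resume(s: Suspended): (Double, Suspended) = {
    // k plays the role of shift's continuation: the work waiting on the
    // value of the deferred recursive call
    val k: Double => Double = rest => s.outer(term(s.n) + rest)
    // Like k(0.0) in suspend: the estimate with the unevaluated tail taken as 0
    val curr = k(0.0)
    // Like next = () => curr + f: later terms are added to the collapsed sum,
    // so no chain of nested closures builds up
    (curr, Suspended(rest => curr + rest, s.n + 1))
  }
}
```

Calling resume(start) yields 4.0 and a new Suspended value; passing that back in yields roughly 2.6667, and so on. Collapsing the sum into a single closure at each step is what keeps the work per step, and the stack depth, constant.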

This behavior is accessed via the calc() function in the following object:

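import scala.util.continuations._

object InterrupterJones {

  // Shorthand for the CPS annotation used throughout
  type dd = cpsParam[Double, Double]

  // Runs the saved computation until it suspends again; returns the estimate so far
  def calc() = reset(next())

  // The work to perform on the next call to calc(); starts at the first term
  private var next: () => Double @dd =
    () => pi()

  private def pi(n: Int = 0): Double @dd =
    4 * math.pow(-1, n) / (2 * n + 1) + suspend(pi(n + 1))

  // Captures the continuation k, evaluates it with 0.0 in place of the
  // remaining terms, and saves the deferred recursive call f for later
  private def suspend(f: => Double @dd): Double @dd =
    shift { k: (Double => Double) =>
      val curr = k(0.0)       // partial sum up to and including this term
      next = () => curr + f   // resume from here on the next calc()
      curr                    // becomes the result of the enclosing reset
    }
}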
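For comparison, a plugin-free analogue of InterrupterJones can keep the deferred work as a plain () => Double in a var. This is a sketch that assumes threading the running sum through by hand is acceptable; `ManualInterrupter` and `step` are hypothetical names, and the recursion has to be restructured to carry the sum explicitly, which is plumbing the CPS transform otherwise handles.

```scala
// Plugin-free sketch mirroring InterrupterJones: the deferred work is an
// ordinary closure stored in a var. Names here are hypothetical.
object ManualInterrupter {

  // Work to run on the next calc(); starts at term 0 with an empty sum
  private var next: () => Double = () => step(0, 0.0)

  // Adds the n-th Leibniz term to the running sum, then defers term n + 1
  // by replacing `next` (like `next = () => curr + f` in suspend)
  private def step(n: Int, sum: Double): Double = {
    val curr = sum + 4 * math.pow(-1, n) / (2 * n + 1)
    next = () => step(n + 1, curr)
    curr
  }

  // Each call runs exactly one step and returns, so the stack never grows
  def calc(): Double = next()
}
```

Successive calls return 4.0, about 2.6667, about 3.4667 and about 2.8952, the same sequence as the session shown next, and a million further calls run in constant stack space because each one returns before the next begins.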

An example interaction looks something like this:

$ scala -P:continuations:enable
Welcome to Scala version (OpenJDK Server VM, Java 1.6.0_22).
Type in expressions to have them evaluated.
Type :help for more information.

scala> :load InterrupterJones.scala
Loading InterrupterJones.scala...
import scala.util.continuations._
defined module InterrupterJones

scala> InterrupterJones.calc()
res0: Double = 4.0

scala> InterrupterJones.calc()
res1: Double = 2.666666666666667

scala> InterrupterJones.calc()
res2: Double = 3.466666666666667

scala> InterrupterJones.calc()
res3: Double = 2.8952380952380956

Each call to calc() runs the recursive pi function for one more iteration, refining the estimate of Pi and updating the continuation to be run on the next call to calc(). Because each call performs just one step and then returns, the stack does not grow from one call to the next, so calc() can be called repeatedly without worrying about a stack overflow:

scala> (1 to 1000000).map(_ => InterrupterJones.calc()).reverse.head
res4: Double = 3.1415916535937742
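A quick consistency check on that last figure: the object's state carries over from the four earlier calls, so the result is the partial sum of the first 4 + 1,000,000 = 1,000,004 terms. For an alternating series with decreasing terms, the error is at most the first omitted term, and the observed error sits within that bound:

```latex
\left| \pi - S_N \right| \le \frac{4}{2N+1}, \qquad
N = 1{,}000{,}004:\ \frac{4}{2{,}000{,}009} \approx 2.0 \times 10^{-6}

\pi - 3.1415916535937742 \approx 1.0 \times 10^{-6}
```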