Replace Loops with Folds


You have a for loop, while loop, or do-while loop that mutates an initial value by some algorithm.

Replace the loop with a fold, passing the initial value and algorithm as arguments.


Loops generally rely on mutable state, and code that mutates state is not referentially transparent.

Folds not only eliminate mutation, they are also more concise. Less code means fewer opportunities for typos and logic errors.


When encountering a loop over a foldable data structure, note what state it's computing with each iteration, and what the initial state is:

val xs = List(1,2,3,4)
var i = 0
var sum = 0
while (i < xs.length) {
  val x = xs(i)
  sum = sum + x
  i = i + 1
}
println(s"sum: ${sum}") // sum: 10

Here, our loop accumulates a sum of the list values. The initial state of the sum is zero.

Rewrite the loop body as a function that takes two arguments, the prior value of the computed state and the current element of the foldable data structure, and returns the next value of the computed state.

val sumInit = 0
val sumOp: (Int,Int) => Int = { (sum, x) => sum + x }

Fold over the original data structure, passing the initial state value and the rewritten loop body function.

Here, we use a left fold:

// foldLeft: List[A] => B => (((B, A) => B) => B)
val foldSum = xs.foldLeft(sumInit)(sumOp)
println(s"fold sum: ${foldSum}") // fold sum: 10
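To see why a fold can stand in for the loop, it may help to look at a hand-written left fold. The following is a minimal sketch, not the standard library's actual implementation; the names myFoldLeft and handSum are hypothetical and exist only for this illustration. The recursion plays the role of the loop counter, and the accumulator parameter plays the role of the mutable variable.

```scala
import scala.annotation.tailrec

// Hypothetical helper for illustration only; the standard library's
// foldLeft is not necessarily implemented this way.
@tailrec
def myFoldLeft[A, B](as: List[A], acc: B)(op: (B, A) => B): B =
  as match {
    // No elements left: the accumulated state is the result.
    case Nil => acc
    // Combine the head into the state, then continue with the tail.
    case head :: tail => myFoldLeft(tail, op(acc, head))(op)
  }

val handSum = myFoldLeft(List(1, 2, 3, 4), 0)((sum, x) => sum + x)
println(s"hand-rolled fold sum: ${handSum}") // hand-rolled fold sum: 10
```

All of the iteration and state threading lives inside the fold, so the caller supplies only the initial value and the combining function.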

Example: looping without a foldable data structure

Not all loops are over a foldable data structure like List. Sometimes we just loop until some arbitrary condition is false:

var count = 0
while (count < 10) {
  count = count + 1
}
println(s"count: ${count}") // count: 10

Here, the loop counts its own iterations and stops after ten.

In this case, we just need to fold over a data structure of the right length. We don't care about its values.

In this example, we can use a Range:

val countInit = 0
val countOp: (Int,Int) => Int = { (count, _) => count + 1 }
val foldCount = (0 until 10).foldLeft(countInit)(countOp)
println(s"fold count: ${foldCount}") // fold count: 10
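When the loop body does use its counter, the values of the Range take the counter's place as the second argument of the function. A minimal sketch, assuming a factorial computation that is not part of the example above; the names n, k, factLoop, and factFold are hypothetical:

```scala
val n = 5

// Loop version: the counter k drives both the condition and the computation.
var k = 1
var factLoop = 1
while (k <= n) {
  factLoop = factLoop * k
  k = k + 1
}

// Fold version: the Range supplies the values the counter would have taken.
val factFold = (1 to n).foldLeft(1)((acc, step) => acc * step)
println(s"fold factorial: ${factFold}") // fold factorial: 120
```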

Example: square each number in a list

Not all loops reduce a bunch of values into a single result. Sometimes the result is another bunch of values.

var j = xs.length - 1
var sq = List[Int]()
while (j >= 0) {
  val x = xs(j)
  sq = (x * x) :: sq
  j = j - 1
}
println(s"squares: ${sq}") // squares: List(1, 4, 9, 16)

Here, we create a new list containing the squares of the numbers in the old list.

This time, we use a right fold:

// foldRight: List[A] => B => (((A, B) => B) => B)
val foldSq = xs.foldRight(List[Int]())((x, z) => (x * x) :: z)
println(s"fold squares: ${foldSq}") // fold squares: List(1, 4, 9, 16)

If we had used a left fold, our result would be backwards:

val foldSqRev = xs.foldLeft(List[Int]())((z, x) => (x * x) :: z)
println(s"fold squares reverse: ${foldSqRev}") // fold squares reverse: List(16, 9, 4, 1)
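The difference in ordering comes from how each fold nests its function applications. The sketch below writes out both expansions by hand for the four-element list. The names nums, rightExpanded, leftExpanded, and foldSqOrdered are hypothetical and used only here. It also shows one possible alternative, assuming you prefer to keep the left fold: reverse its result.

```scala
// Same values as xs above; renamed so this block stands on its own.
val nums = List(1, 2, 3, 4)

// A right fold applies the function to the last element first,
// so prepending each square rebuilds the original order.
val rightExpanded = (1 * 1) :: ((2 * 2) :: ((3 * 3) :: ((4 * 4) :: Nil)))

// A left fold starts at the first element, so each later square
// is prepended in front of the earlier ones.
val leftExpanded = (4 * 4) :: ((3 * 3) :: ((2 * 2) :: ((1 * 1) :: Nil)))

// Keeping the left fold and reversing afterwards restores the order.
val foldSqOrdered = nums.foldLeft(List[Int]())((z, x) => (x * x) :: z).reverse
println(s"fold squares via reverse: ${foldSqOrdered}") // fold squares via reverse: List(1, 4, 9, 16)
```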


sum: 10
fold sum: 10
count: 10
fold count: 10
squares: List(1, 4, 9, 16)
fold squares: List(1, 4, 9, 16)
fold squares reverse: List(16, 9, 4, 1)