You have a for loop, a while loop, or a do while loop that mutates an initial value by some algorithm.
Replace the loop with a fold, passing the initial value and algorithm as arguments.
Loops generally coincide with mutable state, which is not referentially transparent.
Folds not only eliminate mutation but are also more concise. Less code means fewer opportunities for typos and logic errors.
When encountering a loop over a foldable data structure, note what state it's computing with each iteration, and what the initial state is:
val xs = List(1,2,3,4)
var i = 0
var sum = 0 // the state computed by the loop; initial value 0
while (i < xs.length) {
  val x = xs(i)
  sum = sum + x // update the state with the current element
  i = i + 1
}
println(s"sum: ${sum}") // sum: 10
Here, our loop accumulates a sum of the list values. The initial state of the sum is zero.
Rewrite the loop body as a function of two arguments (the prior value of the computed state and the current element of the foldable data structure) that returns the next value of the computed state.
val sumInit = 0
val sumOp: (Int,Int) => Int = { (sum, x) => sum + x }
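Before reaching for the fold, you can check the extraction by calling `sumOp` from inside the original loop. This intermediate step is a sketch and not part of the original recipe. If the loop still prints 10, the function faithfully captures one iteration.

```scala
// Intermediate step: the original while loop, with its body
// replaced by a call to the extracted function sumOp.
val xs = List(1, 2, 3, 4)
val sumInit = 0
val sumOp: (Int, Int) => Int = { (sum, x) => sum + x }

var i = 0
var sum = sumInit
while (i < xs.length) {
  sum = sumOp(sum, xs(i)) // one iteration == one application of sumOp
  i = i + 1
}
println(s"sum via sumOp: ${sum}") // sum via sumOp: 10
```

Once the loop only threads `sum` through `sumOp`, the remaining `var`s are exactly the bookkeeping that a fold takes over.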
Fold over the original data structure, passing the initial state value and the rewritten loop body function.
Here, we use a left fold:
foldLeft: List[A] => B => (((B, A) => B) => B)
val foldSum = xs.foldLeft(sumInit)(sumOp)
println(s"fold sum: ${foldSum}") // fold sum: 10
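The following minimal recursive sketch of a left fold over a List shows why the fold can replace the loop. The name `foldLeftSketch` is hypothetical, and the standard library's `foldLeft` may be implemented differently, but it computes the same result.

```scala
import scala.annotation.tailrec

// Hypothetical helper: a left fold written without mutation.
// The accumulator `acc` plays the role of the loop's `var sum`,
// and each recursive call corresponds to one loop iteration.
@tailrec
def foldLeftSketch[A, B](as: List[A], acc: B)(op: (B, A) => B): B =
  as match {
    case Nil          => acc
    case head :: tail => foldLeftSketch(tail, op(acc, head))(op)
  }

val sketchSum = foldLeftSketch(List(1, 2, 3, 4), 0)((sum, x) => sum + x)
println(sketchSum) // 10
```

Because the recursive call is in tail position, `@tailrec` lets the compiler turn it back into a loop, so the mutation still happens, just hidden behind the fold.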
Not all loops are over a foldable data structure like List. Sometimes we just loop until some arbitrary condition is false:
var count = 0
while (count < 10) {
  count = count + 1
}
println(s"count: ${count}") // count: 10
Here, the loop simply counts its own iterations, stopping after ten.
In this case, we just need to fold over a data structure of the right length. We don't care about its values.
In this example, we can use a Range:
val countInit = 0
val countOp: (Int,Int) => Int = { (count, _) => count + 1 }
val foldCount = (0 until 10).foldLeft(countInit)(countOp)
println(s"fold count: ${foldCount}") // fold count: 10
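The same Range trick also covers a for loop whose body does use each element. The following loop is a hedged illustration and does not come from the original. It sums the numbers 1 through 10:

```scala
// Imperative version: a for loop mutating an accumulator.
var total = 0
for (n <- 1 to 10) {
  total = total + n
}
println(s"loop total: ${total}") // loop total: 55

// Fold version: the range supplies the values, 0 is the initial state,
// and the loop body becomes the (state, element) => state function.
val foldTotal = (1 to 10).foldLeft(0)((acc, n) => acc + n)
println(s"fold total: ${foldTotal}") // fold total: 55
```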
Not all loops reduce a bunch of values into a single result. Sometimes the result is another bunch of values.
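var j = xs.length - 1 // walk the list from the last index to the first
var sq = List[Int]()  // prepending in reverse order keeps the result in order
while (j >= 0) {
  val x = xs(j)
  sq = (x * x) :: sq
  j = j - 1
}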
println(s"squares: ${sq}") // squares: List(1, 4, 9, 16)
Here, we create a new list containing the squares of the numbers in the old list.
This time, we use a right fold:
foldRight: List[A] => B => (((A, B) => B) => B)
val foldSq = xs.foldRight(List[Int]())((x, z) => (x * x) :: z)
println(s"fold squares: ${foldSq}") // fold squares: List(1, 4, 9, 16)
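Why does the right fold keep the original order? A naive recursive sketch of a right fold over List makes it visible. The name `foldRightSketch` is hypothetical and is not the library method. The operation is applied to each head only after its tail has been folded, so the first element ends up at the front:

```scala
// Hypothetical helper: a right fold written as plain recursion.
// op is applied to the head *after* the tail has been folded,
// so the first element ends up outermost in the result.
// Note: this naive version is not tail-recursive and can overflow
// the stack on very long lists.
def foldRightSketch[A, B](as: List[A], z: B)(op: (A, B) => B): B =
  as match {
    case Nil          => z
    case head :: tail => op(head, foldRightSketch(tail, z)(op))
  }

val squares = foldRightSketch(List(1, 2, 3, 4), List.empty[Int])((x, acc) => (x * x) :: acc)
println(squares) // List(1, 4, 9, 16)
```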
If we had used a left fold, our result would be backwards:
val foldSqRev = xs.foldLeft(List[Int]())((z, x) => (x * x) :: z)
println(s"fold squares reverse: ${foldSqRev}") // fold squares reverse: List(16, 9, 4, 1)
sum: 10
fold sum: 10
count: 10
fold count: 10
squares: List(1, 4, 9, 16)
fold squares: List(1, 4, 9, 16)
fold squares reverse: List(16, 9, 4, 1)
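If you still want a left fold for the squares example, the standard library offers two ways to get the original order back, sketched below. You can reverse the left fold's result once at the end. Alternatively, because each square depends on only one input element, you can use `map`, which expresses this kind of element-wise loop directly.

```scala
val xs = List(1, 2, 3, 4)

// Option 1: build the list backwards with foldLeft, then reverse it once.
val viaReverse = xs.foldLeft(List.empty[Int])((z, x) => (x * x) :: z).reverse
println(viaReverse) // List(1, 4, 9, 16)

// Option 2: when each output element depends only on one input element,
// map expresses the same computation directly.
val viaMap = xs.map(x => x * x)
println(viaMap) // List(1, 4, 9, 16)
```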