The State Monad

James Earl Douglas

December 13, 2012

Where we're headed

Monad

trait Monad[A, F[_]] {
  def map[B](f: A => B): F[B]
  def flatMap[B](f: A => F[B]): F[B]
}

For our purposes, a monad abstracts over a type constructor F[_] (for example, List[_]) to provide a couple of functions:

map

Lifts a function A => B into a function F[A] => F[B]

flatMap

Lifts a function A => F[B] into a function F[A] => F[B]
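To make map and flatMap concrete, here is a minimal sketch that implements the slides' Monad trait for List by delegating to List's built-in methods. ListMonad, numbers, tens and signed are hypothetical names for illustration, and the snippet assumes a Scala script or REPL where top-level statements are allowed.

```scala
// The slides' Monad trait, unchanged
trait Monad[A, F[_]] {
  def map[B](f: A => B): F[B]
  def flatMap[B](f: A => F[B]): F[B]
}

// Hypothetical wrapper: a Monad for List that delegates to List's own methods
class ListMonad[A](xs: List[A]) extends Monad[A, List] {
  def map[B](f: A => B): List[B] = xs.map(f)
  def flatMap[B](f: A => List[B]): List[B] = xs.flatMap(f)
}

val numbers = new ListMonad(List(1, 2, 3))

// map lifts Int => Int so it works on a whole List[Int]
val tens = numbers.map(n => n * 10)

// flatMap lifts Int => List[Int] and flattens the resulting lists into one
val signed = numbers.flatMap(n => List(n, -n))

println(tens)   // List(10, 20, 30)
println(signed) // List(1, -1, 2, -2, 3, -3)
```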

Where we're headed

State monad*

Abstracts over a state-passing function S => (A, S)

Our type constructor here is S => (_, S), where S is the type of the state being threaded through the computation (any type will do).

class StateMonad[S, A](g: S => (A, S)) {
  def map[B](f: A => B): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      (f(a), state1)
    }
  def flatMap[B](f: A => (S => (B, S))): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      f(a)(state1)
    }
}
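A quick way to see what map and flatMap do is to run them on a state that is just an Int counter. The sketch below reuses the StateMonad class unchanged; tick, scaled and twoTicks are hypothetical names, and the snippet assumes a Scala script or REPL.

```scala
class StateMonad[S, A](g: S => (A, S)) {
  def map[B](f: A => B): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      (f(a), state1)
    }
  def flatMap[B](f: A => (S => (B, S))): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      f(a)(state1)
    }
}

// Hypothetical example: the state is an Int counter.
// tick returns the current counter value and increments the state.
val tick: Int => (Int, Int) = n => (n, n + 1)

// map transforms the result and leaves the state transition alone
val scaled: Int => (Int, Int) = new StateMonad(tick).map(a => a * 10)

// flatMap feeds the first result and the updated state into a second step
val twoTicks: Int => (Int, Int) =
  new StateMonad(tick).flatMap(a => (n: Int) => (a + n, n + 1))

println(scaled(5))   // (50,6)
println(twoTicks(0)) // (1,2): first tick yields 0 with state 1, second adds 1 and moves to 2
```

Note that map and flatMap return bare functions rather than StateMonad values, so each result is run by applying it to a starting state.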

Start simple

Some impure code

def parseInt(x: String): Int = x.toInt
def double(x: Int): Int = x * 2
def meaningOfLife(x: Int): Boolean = x == 42

meaningOfLife(double(parseInt("1")))
  // false

meaningOfLife(double(parseInt("21")))
  // true

meaningOfLife(double(parseInt("twenty-one")))
  // throws java.lang.NumberFormatException

The problem

A single bad input throws an exception that escapes the whole expression, so nothing downstream gets a result.
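The sketch below (assuming a Scala script or REPL; outcome is a hypothetical name) shows that the exception from parseInt skips past double and meaningOfLife, so the only way to recover is a try/catch around the whole expression.

```scala
def parseInt(x: String): Int = x.toInt
def double(x: Int): Int = x * 2
def meaningOfLife(x: Int): Boolean = x == 42

// The exception thrown by parseInt means double and meaningOfLife never run;
// only an enclosing try/catch sees it.
val outcome: String =
  try {
    meaningOfLife(double(parseInt("twenty-one"))).toString
  } catch {
    case e: NumberFormatException => "blew up: " + e.getMessage
  }

println(outcome)
```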

Don't throw exceptions

def parseInt(x: String): Int =
  try {
    x.toInt
  } catch {
    case t: NumberFormatException =>
      println(t.toString)
      0
  }

meaningOfLife(double(parseInt("1")))
  // false

meaningOfLife(double(parseInt("21")))
  // true

meaningOfLife(double(parseInt("twenty-one")))
  // false, and a message is printed to stdout

The problem

This is better, but parse errors still change the state of the world (in this case, stdout).

Save the world

We want a way to describe how we'll change the world without actually doing it. Our function will take the current state of the world, perform an operation, and return both the result and an updated state of the world.

currentWorldState => (computationResult, newWorldState)

For example, a function that returns true and prepends a message to the log:

log => (true, "some message" :: log)
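Applying that function to an existing log shows the pattern: the caller gets both the result and a new log, while the old log is left as it was. logTrue, before, result and after are hypothetical names in a sketch that assumes a Scala script or REPL.

```scala
// A state-passing function where the "world" is an immutable log of messages
val logTrue: List[String] => (Boolean, List[String]) =
  log => (true, "some message" :: log)

val before = List("earlier entry")
val (result, after) = logTrue(before)

println(result) // true
println(after)  // List(some message, earlier entry)
println(before) // List(earlier entry): the old log is untouched
```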

Log the parse error

def parseInt(x: String): List[String] => (Int, List[String]) =
  log =>
    try {
      (x.toInt, log)
    } catch {
      case t: NumberFormatException =>
        (0, t.toString :: log)
    }

Purity achieved

Now we keep track of both the parse result and any parse error, recorded as a log entry rather than printed.
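Running the log-passing parseInt requires a starting log; the sketch below uses Nil. It repeats the slide's definition (with the catch narrowed to NumberFormatException) so it runs on its own in a Scala script or REPL; good and bad are hypothetical names.

```scala
// The log-passing parseInt, catching only parse failures
def parseInt(x: String): List[String] => (Int, List[String]) =
  log =>
    try {
      (x.toInt, log)
    } catch {
      case t: NumberFormatException =>
        (0, t.toString :: log)
    }

// Nothing runs until we supply a starting log; Nil is an empty one
val good = parseInt("21")(Nil)
val bad  = parseInt("twenty-one")(Nil)

println(good) // (21,List())
println(bad)  // the pair (0, log) where the log holds one NumberFormatException entry
```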

Log other stuff too

def meaningOfLife(x: Int): List[String] => (Boolean, List[String]) =
  log =>
    if (x == 42) {
      (true, "found the meaning of life" :: log)
    } else {
      (false, log)
    }

More purity

As before, we do some computation and possibly log a message, all without mutating anything.
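Here meaningOfLife is run against a log that already has an entry, as a sketch (Scala script or REPL assumed; existing, hit and miss are hypothetical names). A hit prepends a message; a miss hands back the same log.

```scala
def meaningOfLife(x: Int): List[String] => (Boolean, List[String]) =
  log =>
    if (x == 42) {
      (true, "found the meaning of life" :: log)
    } else {
      (false, log)
    }

val existing = List("parsed ok")

val hit  = meaningOfLife(42)(existing)
val miss = meaningOfLife(7)(existing)

println(hit)      // (true,List(found the meaning of life, parsed ok))
println(miss)     // (false,List(parsed ok))
println(existing) // List(parsed ok): nothing was mutated
```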

Hold the phone

How the heck do we compose this with parseInt and double?
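Without any abstraction, composition means threading the log by hand: every stateful step returns a (result, log) pair that must be unpacked before the next step can run. The helper pipeline below is hypothetical and exists only to show that plumbing, which is the part flatMap captures. It assumes a Scala script or REPL.

```scala
def parseInt(x: String): List[String] => (Int, List[String]) =
  log =>
    try { (x.toInt, log) }
    catch { case t: NumberFormatException => (0, t.toString :: log) }

def double(x: Int): Int = x * 2

def meaningOfLife(x: Int): List[String] => (Boolean, List[String]) =
  log =>
    if (x == 42) (true, "found the meaning of life" :: log)
    else (false, log)

// Hypothetical helper: compose the three steps by threading the log by hand.
// Each stateful step forces us to unpack (result, log) and pass the log on.
def pipeline(input: String): List[String] => (Boolean, List[String]) =
  log0 => {
    val (parsed, log1) = parseInt(input)(log0)
    val doubled = double(parsed)
    meaningOfLife(doubled)(log1)
  }

println(pipeline("21")(Nil)) // (true,List(found the meaning of life))
```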

Compose it

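import scala.language.implicitConversions

// Lets a bare S => (A, S) function use StateMonad's map and flatMap
implicit def stateMonad[S, A](g: S => (A, S)): StateMonad[S, A] = new StateMonad(g)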

parseInt("1").map(double).flatMap(meaningOfLife)
  // a function that needs a log to run

for {
  parsed  <- parseInt("1")
  doubled  = double(parsed)
  meaning <- meaningOfLife(doubled)
} yield meaning
  // equivalent to the above function
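The for-comprehension is syntactic sugar for map and flatMap. The sketch below writes out roughly what the Scala compiler produces for it: the doubled binding rides along in a tuple produced by map, which flatMap then unpacks. The exact desugaring is a compiler detail, so treat this as an approximation; desugared is a hypothetical name, and the snippet assumes a Scala script or REPL.

```scala
import scala.language.implicitConversions

class StateMonad[S, A](g: S => (A, S)) {
  def map[B](f: A => B): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      (f(a), state1)
    }
  def flatMap[B](f: A => (S => (B, S))): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      f(a)(state1)
    }
}

implicit def stateMonad[S, A](g: S => (A, S)): StateMonad[S, A] = new StateMonad(g)

def parseInt(x: String): List[String] => (Int, List[String]) =
  log =>
    try { (x.toInt, log) }
    catch { case t: NumberFormatException => (0, t.toString :: log) }

def double(x: Int): Int = x * 2

def meaningOfLife(x: Int): List[String] => (Boolean, List[String]) =
  log =>
    if (x == 42) (true, "found the meaning of life" :: log)
    else (false, log)

// Approximate desugaring: map carries (parsed, doubled) forward as a tuple,
// flatMap unpacks it and runs the last stateful step.
val desugared: List[String] => (Boolean, List[String]) =
  parseInt("1")
    .map { parsed => val doubled = double(parsed); (parsed, doubled) }
    .flatMap { case (parsed, doubled) => meaningOfLife(doubled).map(meaning => meaning) }

println(desugared(Nil)) // (false,List())
```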

parseInt("1").map(double).flatMap(meaningOfLife)(Nil)
  // (false,List())

parseInt("21").map(double).flatMap(meaningOfLife)(Nil)
  // (true,List(found the meaning of life))

parseInt("twenty-one").map(double).flatMap(meaningOfLife)(Nil)
  // (false,List(java.lang.NumberFormatException: For input string: "twenty-one"))

Bonus

Explicitly set and get state

def set[S](s: S): StateMonad[S, S] = new StateMonad(_ => (s, s))
def get[S]: StateMonad[S, S] = new StateMonad(s => (s, s))
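Because set and get return StateMonad wrappers rather than bare functions, running one on its own means mapping over it first to obtain an S => (B, S) function. This sketch spells out the type arguments to new StateMonad to keep inference simple; sizeOfState and replaceState are hypothetical names, and a Scala script or REPL is assumed.

```scala
class StateMonad[S, A](g: S => (A, S)) {
  def map[B](f: A => B): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      (f(a), state1)
    }
  def flatMap[B](f: A => (S => (B, S))): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      f(a)(state1)
    }
}

// set ignores the incoming state and replaces it; get exposes it as the result
def set[S](s: S): StateMonad[S, S] = new StateMonad[S, S](_ => (s, s))
def get[S]: StateMonad[S, S] = new StateMonad[S, S](s => (s, s))

// Map over the wrappers to turn them into runnable state-passing functions
val sizeOfState  = get[List[String]].map(entries => entries.size)
val replaceState = set(List("x")).map(_ => "replaced")

println(sizeOfState(List("a", "b")))  // (2,List(a, b))
println(replaceState(List("a", "b"))) // (replaced,List(x))
```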

for {
  _       <- set(List("foo", "bar"))
  parsed  <- parseInt("1")
  state   <- get[List[String]]
  _        = println("sneaky peek at the state: " + state)
  doubled  = double(parsed)
  meaning <- meaningOfLife(doubled)
} yield meaning
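Putting every piece from these slides into one file gives a runnable version of the bonus program, as a sketch assuming a Scala script or REPL. It narrows the catch to NumberFormatException and spells out the type arguments in set and get; program and result are hypothetical names. The println inside the for-comprehension fires only when the program is applied to a starting log, not when it is built.

```scala
import scala.language.implicitConversions

class StateMonad[S, A](g: S => (A, S)) {
  def map[B](f: A => B): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      (f(a), state1)
    }
  def flatMap[B](f: A => (S => (B, S))): (S => (B, S)) =
    state => {
      val (a, state1) = g(state)
      f(a)(state1)
    }
}

implicit def stateMonad[S, A](g: S => (A, S)): StateMonad[S, A] = new StateMonad(g)

def parseInt(x: String): List[String] => (Int, List[String]) =
  log =>
    try { (x.toInt, log) }
    catch { case t: NumberFormatException => (0, t.toString :: log) }

def double(x: Int): Int = x * 2

def meaningOfLife(x: Int): List[String] => (Boolean, List[String]) =
  log =>
    if (x == 42) (true, "found the meaning of life" :: log)
    else (false, log)

def set[S](s: S): StateMonad[S, S] = new StateMonad[S, S](_ => (s, s))
def get[S]: StateMonad[S, S] = new StateMonad[S, S](s => (s, s))

val program: List[String] => (Boolean, List[String]) =
  for {
    _       <- set(List("foo", "bar"))
    parsed  <- parseInt("1")
    state   <- get[List[String]]
    _        = println("sneaky peek at the state: " + state)
    doubled  = double(parsed)
    meaning <- meaningOfLife(doubled)
  } yield meaning

// The starting log is ignored, because set replaces it before anything else runs.
// Running prints: sneaky peek at the state: List(foo, bar)
val result = program(Nil)
println(result) // (false,List(foo, bar))
```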
