Not in the Box; in the BandT

January 22, 2017

This is a quick refresher on how monad transformers work in Scala.

Monad

/** Type class listing the two operations this post requires of a type constructor `F`.
  *
  * @tparam F the single-argument type constructor the operations work on
  */
trait Monad[F[_]] {

  /** Turns an `F[A]` into an `F[B]` using a plain function `A => B`.
    * Note the argument order: the function comes first, the `F[A]` second.
    */
  def map[A, B](f: A => B)(fa: F[A]): F[B]

  /** Turns an `F[A]` into an `F[B]` using a function that itself returns an `F[B]`. */
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
}

/** Method-style counterpart of the `Monad` operations: the value being transformed is the
  * receiver rather than an argument, which is the shape `for`-comprehensions desugar to.
  *
  * @tparam A the element type of the wrapped value
  * @tparam F the type constructor of the results
  */
trait Monadic[A, F[_]] {

  /** Produces an `F[B]` from a plain function `A => B`. */
  def map[B](f: A => B): F[B]

  /** Produces an `F[B]` from a function that itself returns an `F[B]`. */
  def flatMap[B](f: A => F[B]): F[B]
}

Identity

// The identity type constructor: `Id[A]` is just another name for `A` itself.
type Id[A] = A

/** `Monad` instance for `Id`. Since `Id[A]` is simply `A`, both operations amount to
  * applying the function directly to the value.
  */
implicit object IdMonad extends Monad[Id] {

  // `Id[B]` is `B`, so a plain `A => B` already has the shape `flatMap` expects.
  def map[A, B](f: A => B)(fa: Id[A]): Id[B] =
    flatMap[A, B](fa)(f)

  def flatMap[A, B](fa: Id[A])(f: A => Id[B]): Id[B] =
    f.apply(fa)
}

/** Adds method-style `map` and `flatMap` to a value viewed as `Id[A]`;
  * both calls are forwarded to `IdMonad`.
  */
implicit class IdMonadic[A](fa: Id[A]) extends Monadic[A, Id] {

  def map[B](f: A => B): Id[B] =
    IdMonad.map[A, B](f)(fa)

  def flatMap[B](f: A => Id[B]): Id[B] =
    IdMonad.flatMap[A, B](fa)(f)
}

Box

// A container holding exactly one value of type `A`.
case class Box[A](a: A)

/** `Monad` instance for `Box`. `map` prints a trace line on every call so that its
  * invocations are visible in the program output; `flatMap` prints nothing.
  */
implicit object BoxMonad extends Monad[Box] {

  def map[A, B](f: A => B)(fa: Box[A]): Box[B] = {
    // The trace line is emitted before the function is applied.
    println(s"*** Box.map(f)(${fa})")
    val mapped = f(fa.a)
    Box(mapped)
  }

  def flatMap[A, B](fa: Box[A])(f: A => Box[B]): Box[B] = {
    val unwrapped = fa.a
    f(unwrapped)
  }
}

/** Adds method-style `map` and `flatMap` to a `Box[A]`;
  * both calls are forwarded to `BoxMonad`.
  */
implicit class BoxMonadic[A](fa: Box[A]) extends Monadic[A, Box] {

  def map[B](f: A => B): Box[B] =
    BoxMonad.map[A, B](f)(fa)

  def flatMap[B](f: A => Box[B]): Box[B] =
    BoxMonad.flatMap[A, B](fa)(f)
}

Band and its transformer

// Wraps a value of type `F[A]`. Note the parameter order: the element type `A`
// comes first and the type constructor `F` second.
case class BandT[A,F[_]](fa: F[A])

/** Companion helpers for the `BandT` case class. */
object BandT {

  /** Wraps an existing `F[A]` in a `BandT`; a `Monad` instance for `F` must be in scope. */
  def lift[A, F[_]: Monad](fa: F[A]): BandT[A, F] =
    new BandT[A, F](fa)
}

/** Derives a `Monad` for `BandT[_, F]` from the `Monad` instance of the underlying `F`.
  *
  * Both operations unwrap the `BandT`, delegate to `F`'s own `map`/`flatMap`, and wrap
  * the result again.
  *
  * The result type is declared explicitly: Scala 3 rejects implicit definitions without
  * one, and in Scala 2 an implicit with an inferred type is only eligible at use sites
  * that come after its definition in the file.
  *
  * @tparam F the wrapped type constructor; must itself have a `Monad` instance
  */
implicit def bandTMonad[F[_]: Monad]: Monad[({ type λ[α] = BandT[α, F] })#λ] =
  new Monad[({ type λ[α] = BandT[α, F] })#λ] {

    def map[A, B](f: A => B)(fa: BandT[A, F]): BandT[B, F] =
      BandT(implicitly[Monad[F]].map(f)(fa.fa))

    // `f` returns a BandT, so its result is unwrapped (`.fa`) before handing it to F's flatMap.
    def flatMap[A, B](fa: BandT[A, F])(f: A => BandT[B, F]): BandT[B, F] =
      BandT(implicitly[Monad[F]].flatMap(fa.fa)(a => f(a).fa))
  }

/** Adds method-style `map` and `flatMap` to a `BandT[A, F]`;
  * both calls are forwarded to the instance produced by `bandTMonad[F]`.
  */
implicit class BandTMonadic[A, F[_]: Monad](fa: BandT[A, F])
    extends Monadic[A, ({ type λ[α] = BandT[α, F] })#λ] {

  def map[B](f: A => B): BandT[B, F] =
    bandTMonad[F].map[A, B](f)(fa)

  def flatMap[B](f: A => BandT[B, F]): BandT[B, F] =
    bandTMonad[F].flatMap[A, B](fa)(f)
}

// `Band` is `BandT` with the wrapped type constructor fixed to `Id`.
type Band[A] = BandT[A,Id]

/** Constructor helper for the `Band` alias. */
object Band {

  /** Builds a `Band[A]` from a bare value; the value is stored as an `Id[A]`, i.e. unchanged. */
  def apply[A](a: A): Band[A] =
    new BandT[A, Id](a)
}

Demo

// Each result is computed and printed before the next comprehension runs,
// so the order of the output lines is unchanged.

// Two values wrapped in Box.
val boxProduct =
  for {
    x <- Box(6)
    y <- Box(7)
  } yield x * y
println(boxProduct)

// Two values wrapped in Band.
val bandProduct =
  for {
    x <- Band(6)
    y <- Band(7)
  } yield x * y
println(bandProduct)

// BandT over Box, built once with the constructor and once with `lift`.
val bandTProduct =
  for {
    x <- BandT(Box(6))
    y <- BandT.lift(Box(7))
  } yield x * y
println(bandTProduct)

Usage

$ curl -sL earldouglas.com/posts/box-band.md | codedown scala > box-band.scala
$ scala box-band.scala
*** Box.map(f)(Box(7))
Box(42)
BandT(42)
*** Box.map(f)(Box(7))
BandT(Box(42))