Replace dependency injection with the reader monad

Summary

You have multiple functions that reuse an argument, either using it directly or only passing it along to another function call.

Remove the argument from functions that don't need it, and remodel the functions that do need it so that, once their other arguments are applied, they are functions of that argument alone.

val PI = 3.14

def area(pi: Double, r: Int): Double = pi * r * r
def volume(pi: Double, r: Int, h: Int): Double = {
  val a = area(pi, r)
  val v = a * h
  v
}
val vol = volume(PI, 6, 7)
println(s"vol: ${vol}") // vol: 791.28
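The curried form described in the summary can be sketched with plain Scala functions, before any reader type is introduced. `areaC`, `volumeC` and `volumeAt67` are hypothetical names used only for illustration. Moving pi into its own, last parameter list means that applying only the radius and height leaves a `Double => Double` that is still waiting for pi:

```scala
// Curried variants: the shared dependency (pi) moves to the last parameter list.
def areaC(r: Int)(pi: Double): Double = pi * r * r

// volumeC never uses pi itself; it only forwards it to areaC.
def volumeC(r: Int, h: Int)(pi: Double): Double = areaC(r)(pi) * h

// Applying everything except pi yields a function that still needs pi.
val volumeAt67: Double => Double = volumeC(6, 7)

println(volumeAt67(3.14))
```

A function of this shape, `E => A`, is exactly what the reader type below wraps.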

We use the following form for our reader monad:

case class Reader[E,A](run: E => A) {
  def flatMap[B](f: A => Reader[E,B]): Reader[E,B] =
    Reader[E,B] { e =>
      f(run(e)).run(e)
    }
  def map[B](f: A => B): Reader[E,B] =
    Reader[E,B] { e =>
      f(run(e))
    }
}
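To see `map` and `flatMap` at work on this definition, here is a small sketch. The helper `ask`, a reader that returns its environment unchanged, is an assumption added here; it is a common convention for readers but is not part of the definition above. The for-comprehension desugars into the `flatMap` and `map` calls defined in the case class:

```scala
case class Reader[E, A](run: E => A) {
  def flatMap[B](f: A => Reader[E, B]): Reader[E, B] =
    Reader[E, B] { e => f(run(e)).run(e) }
  def map[B](f: A => B): Reader[E, B] =
    Reader[E, B] { e => f(run(e)) }
}

// Hypothetical helper: a reader that yields the environment itself.
def ask[E]: Reader[E, E] = Reader[E, E](e => e)

// Both steps read the same environment; neither takes it as an argument.
val greeting: Reader[String, String] =
  for {
    name <- ask[String]
    len  <- Reader((s: String) => s.length)
  } yield s"$name has $len characters"

println(greeting.run("Folky"))
```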

Now we can extract Pi as a dependency:

def areaR(r: Int): Reader[Double,Double] =
  Reader { pi => pi * r * r }

def volumeR(r: Int, h: Int): Reader[Double,Double] =
  areaR(r) map { a => a * h }
val r = 6
val h = 7

val volR = volumeR(r, h) run PI
println(s"volR: ${volR}") // volR: 791.28
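Because pi is supplied only when the reader is run, the same `volumeR(6, 7)` value can be run against a different environment without touching `areaR` or `volumeR`. A self-contained sketch, repeating the definitions above and assuming `scala.math.Pi` as the alternative value (the names `cylinder`, `rough` and `precise` are illustrative):

```scala
case class Reader[E, A](run: E => A) {
  def flatMap[B](f: A => Reader[E, B]): Reader[E, B] =
    Reader[E, B] { e => f(run(e)).run(e) }
  def map[B](f: A => B): Reader[E, B] =
    Reader[E, B] { e => f(run(e)) }
}

def areaR(r: Int): Reader[Double, Double] =
  Reader[Double, Double] { pi => pi * r * r }

def volumeR(r: Int, h: Int): Reader[Double, Double] =
  areaR(r) map { a => a * h }

val cylinder = volumeR(6, 7)          // nothing is computed yet
val rough    = cylinder run 3.14      // pi supplied at the edge
val precise  = cylinder run math.Pi   // same reader, different pi

println(s"rough: $rough, precise: $precise")
```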

volumeR still takes the radius as a plain argument. We can turn the radius into a reader dependency as well, giving a reader of the radius that yields a reader of pi:

def volumeRR(h: Int): Reader[Int,Reader[Double,Double]] =
  Reader { r => areaR(r) map { a => a * h } }
val volRR = volumeRR(h) run r run PI
println(s"volRR: ${volRR}") // volRR: 791.28
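Nesting readers is one way to carry two dependencies. An alternative, not the approach taken above, is a single reader over a combined environment. In this sketch `Env` and `flatten` are hypothetical names; `flatten` shows that the nested reader carries the same information as one reader of an `Env`:

```scala
case class Reader[E, A](run: E => A) {
  def flatMap[B](f: A => Reader[E, B]): Reader[E, B] =
    Reader[E, B] { e => f(run(e)).run(e) }
  def map[B](f: A => B): Reader[E, B] =
    Reader[E, B] { e => f(run(e)) }
}

// Hypothetical combined environment holding both dependencies.
case class Env(pi: Double, r: Int)

def areaR(r: Int): Reader[Double, Double] =
  Reader[Double, Double] { pi => pi * r * r }

// Nested form: a reader of the radius that yields a reader of pi.
def volumeRR(h: Int): Reader[Int, Reader[Double, Double]] =
  Reader[Int, Reader[Double, Double]] { r => areaR(r) map { a => a * h } }

// Collapse the two layers into a single reader over Env.
def flatten[A](rr: Reader[Int, Reader[Double, A]]): Reader[Env, A] =
  Reader[Env, A] { env => rr.run(env.r).run(env.pi) }

val volE: Reader[Env, Double] = flatten(volumeRR(7))
println(volE run Env(3.14, 6))
```

The nested form keeps each dependency visible in the type; the combined form is easier to run, at the cost of every reader seeing the whole `Env`.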

Motivation

When dependencies are injected at the top level, they have to be passed around from one function to another, regardless of whether or not any one particular function actually uses them.

The reader monad lets us encode a dependency directly in a function's type, and readers that share that dependency compose with map and flatMap without anyone passing it along by hand.

Mechanics

When encountering functions that pass an argument around to each other, extract the common dependency, convert the blocks that need it to readers of that dependency, and put them back together using map and flatMap.
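When two blocks both need the dependency, flatMap (here via a for-comprehension) threads the same value into each of them. As an illustration beyond the original example, this sketch assumes a hypothetical cylinder surface area, 2πr² + 2πrh, built from two readers of pi; `sideR` and `surfaceR` are illustrative names:

```scala
case class Reader[E, A](run: E => A) {
  def flatMap[B](f: A => Reader[E, B]): Reader[E, B] =
    Reader[E, B] { e => f(run(e)).run(e) }
  def map[B](f: A => B): Reader[E, B] =
    Reader[E, B] { e => f(run(e)) }
}

def areaR(r: Int): Reader[Double, Double] =
  Reader[Double, Double] { pi => pi * r * r }

// Hypothetical: lateral (side) area of a cylinder, 2 * pi * r * h.
def sideR(r: Int, h: Int): Reader[Double, Double] =
  Reader[Double, Double] { pi => 2 * pi * r * h }

// Both readers receive the same pi when the combined reader is run.
def surfaceR(r: Int, h: Int): Reader[Double, Double] =
  for {
    end  <- areaR(r)      // one circular end
    side <- sideR(r, h)
  } yield 2 * end + side

println(surfaceR(1, 1) run 1.0)   // 4.0
```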

Example

case class Folk(id: Int, name: String)

We define a bunch of database functions, each of which needs a database connection value.

def initDb(c: java.sql.Connection): Unit = {
  val s1 =
    c.prepareStatement(
      """|CREATE TABLE IF NOT EXISTS FOLKS (
         |  ID INT NOT NULL,
         |  NAME VARCHAR(1024),
         |  PRIMARY KEY (ID)
         |)""".stripMargin
    )
  s1.execute()
  s1.close()
}

def addFolk(f: Folk)(c: java.sql.Connection): Unit = {
  val s2 =
    c.prepareStatement(
      """|INSERT INTO FOLKS (ID, NAME)
         |VALUES (?, ?)""".stripMargin
    )
  s2.setInt(1, f.id)
  s2.setString(2, f.name)
  s2.execute()
  s2.close()
}

def getFolk(id: Int)(c: java.sql.Connection): Option[Folk] = {
  val s3 =
    c.prepareStatement(
      "SELECT ID, NAME FROM FOLKS WHERE ID = ?"
    )
  s3.setInt(1, id)
  val rs3 = s3.executeQuery()
  val folk =
    if (rs3.next()) {
      Some(Folk(rs3.getInt("ID"), rs3.getString("NAME")))
    } else {
      None
    }
  s3.close()
  folk
}

def showFolk(f: Folk): String =
  s"Folk ${f.id}: ${f.name}"

def close(c: java.sql.Connection): Unit =
  c.close()

To put it together, we need a database connection value which we pass around to each database function.

def demo(c: java.sql.Connection): Unit = {
  initDb(c)
  addFolk(Folk(1, "Folky McFolkface"))(c)
  val fO = getFolk(1)(c)
  val sO = fO map showFolk
  sO foreach println
  close(c)
}

Class.forName("org.hsqldb.jdbcDriver")
demo(java.sql.DriverManager.getConnection("jdbc:hsqldb:mem:demo", "sa", ""))

Reader monad

A reader encapsulates a function, making it composable with other readers.

We can convert the imperative database functions to readers to hide the database connection values.

val initDbR: Reader[java.sql.Connection,Unit] =
  Reader(initDb(_))

def addFolkR(f: Folk): Reader[java.sql.Connection,Unit] =
  Reader(addFolk(f))

def getFolkR(id: Int): Reader[java.sql.Connection,Option[Folk]] =
  Reader(getFolk(id))

def showFolkR(id: Int): Reader[java.sql.Connection,Option[String]] =
  for {
    fO <- getFolkR(id)
    sO  = for {
            f <- fO
          } yield showFolk(f)
  } yield sO

val closeR: Reader[java.sql.Connection,Unit] =
  Reader(close)
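The nested for-comprehension in showFolkR is a map inside a map: the outer one over the reader, the inner one over the Option. A sketch that runs without a database, assuming a hypothetical environment of type `Map[Int, Folk]` in place of the connection:

```scala
case class Folk(id: Int, name: String)

case class Reader[E, A](run: E => A) {
  def flatMap[B](f: A => Reader[E, B]): Reader[E, B] =
    Reader[E, B] { e => f(run(e)).run(e) }
  def map[B](f: A => B): Reader[E, B] =
    Reader[E, B] { e => f(run(e)) }
}

def showFolk(f: Folk): String = s"Folk ${f.id}: ${f.name}"

// Hypothetical stand-in for getFolkR: look the folk up in an in-memory map.
def getFolkR(id: Int): Reader[Map[Int, Folk], Option[Folk]] =
  Reader[Map[Int, Folk], Option[Folk]] { db => db.get(id) }

// Same shape as showFolkR: map over the reader, then over the Option.
def showFolkR(id: Int): Reader[Map[Int, Folk], Option[String]] =
  getFolkR(id) map { fO => fO map showFolk }

val db = Map(1 -> Folk(1, "Folky McFolkface"))
println(showFolkR(1) run db)
println(showFolkR(2) run db)
```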

With the demo converted to a reader, we never deal with a database connection value directly; the connection is supplied once, when the reader is run.

val demoR: Reader[java.sql.Connection,Unit] =
  for {
    _  <- initDbR
    _  <- addFolkR(Folk(1, "Folky McFolkface"))
    fO <- getFolkR(1)
    sO  = fO map showFolk
    _   = sO foreach println
    _  <- closeR // bind with <- so the reader runs and the connection is closed
  } yield ()

Class.forName("org.hsqldb.jdbcDriver")
demoR.run(java.sql.DriverManager.getConnection("jdbc:hsqldb:mem:demoR", "sa", ""))
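In a for-comprehension over readers, `<-` runs the reader on its right and binds the result, whereas `=` only binds a value: a reader bound with `=` is constructed but never run. A small sketch that makes this visible, using a hypothetical `Counter` environment in place of a connection:

```scala
case class Reader[E, A](run: E => A) {
  def flatMap[B](f: A => Reader[E, B]): Reader[E, B] =
    Reader[E, B] { e => f(run(e)).run(e) }
  def map[B](f: A => B): Reader[E, B] =
    Reader[E, B] { e => f(run(e)) }
}

// Hypothetical mutable environment that records how often it is touched.
final class Counter { var hits = 0 }

val touch: Reader[Counter, Unit] = Reader[Counter, Unit] { c => c.hits += 1 }

val program: Reader[Counter, Unit] =
  for {
    _ <- touch   // runs: hits becomes 1
    _  = touch   // only builds a Reader value; it is never run
    _ <- touch   // runs: hits becomes 2
  } yield ()

val counter = new Counter
program.run(counter)
println(counter.hits)   // 2
```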

Demo

/***
libraryDependencies += "org.hsqldb" % "hsqldb" % "2.3.3"
*/
$ curl -s https://earldouglas.com/posts/itof/di-to-reader.md | codedown scala > di-to-reader.scala
$ sbt -Dsbt.main.class=sbt.ScriptMain di-to-reader.scala
vol: 791.28
volR: 791.28
volRR: 791.28
Folk 1: Folky McFolkface
Folk 1: Folky McFolkface