Replace Dependency Injection with the Reader Monad

Summary

You have functions with different dependencies and want to compose them.

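Move each function's dependencies out of its ordinary arguments and into a separate, final parameter list, so that the function is curried on those dependencies. Partially applying the ordinary arguments then yields a function from the dependencies to the result, which a reader can wrap.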
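A minimal sketch of this step, assuming a hypothetical Config dependency that is not part of the original example. The dependency moves from the ordinary argument list into a final parameter list of its own, so supplying only the ordinary arguments leaves a function that is still waiting for its dependency.

```scala
// Wrapped in an object so these names don't clash with the rest of the file.
object CurryingSketch:

  // Hypothetical dependency, used only for illustration.
  final case class Config(greeting: String)

  // Before: the dependency is mixed in with the ordinary arguments.
  def greetBefore(name: String, cfg: Config): String =
    s"${cfg.greeting}, $name"

  // After: the dependency gets its own, final parameter list.
  def greet(name: String)(cfg: Config): String =
    s"${cfg.greeting}, $name"

  // Applying only the ordinary arguments leaves a function that still
  // needs its dependency; a function of this shape is what a reader wraps.
  val greetAlice: Config => String = greet("Alice")
```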

Motivation

The reader monad lets us encode a function's dependencies directly in its type, while keeping such functions composable with map and flatMap, and therefore with for-comprehensions.
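As a sketch of what "encoding dependencies in the type" means, assume two hypothetical capabilities, HasName and HasGreeting. Composing a reader that needs one with a reader that needs the other produces a reader whose type requires both, as the intersection HasName & HasGreeting. The Reader class is the same one defined in the Mechanics section, repeated so the sketch stands alone.

```scala
// Wrapped in an object so these names don't clash with the rest of the file.
object AccumulationSketch:

  // The same Reader as in the Mechanics section.
  case class Reader[-E, +A](run: E => A):
    def map[E2 <: E, B](f: A => B): Reader[E2, B] =
      Reader(e => f(run(e)))
    def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
      Reader(e => f(run(e)).run(e))

  // Two hypothetical capabilities.
  trait HasName:
    def name: String

  trait HasGreeting:
    def greeting: String

  // Each reader's type states the one capability it needs.
  val nameR: Reader[HasName, String] = Reader((e: HasName) => e.name)
  val greetingR: Reader[HasGreeting, String] = Reader((e: HasGreeting) => e.greeting)

  // Composing them yields a reader whose type requires both.
  val messageR: Reader[HasName & HasGreeting, String] =
    for
      n <- nameR
      g <- greetingR
    yield s"$g, $n!"
```

Calling messageR.run with a value that implements only HasName would be rejected at compile time, because the missing capability is visible in the type.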

Mechanics

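A reader wraps a function from an environment of type E to a result of type A. Its map and flatMap methods make it composable with other readers; the E2 <: E bound on each method lets a reader be combined with readers that require a more specific environment, such as an intersection of capabilities.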

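// A Reader[E, A] is a computation that needs an environment of type E
// to produce a value of type A.
case class Reader[-E, +A](run: E => A):

  // Transform the result; the environment is passed through untouched.
  def map[E2 <: E, B](f: A => B): Reader[E2, B] =
    Reader(e => f(run(e)))

  // Sequence two readers, passing the same environment to both.
  def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
    Reader(e => f(run(e)).run(e))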
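To see what map and flatMap do, here is a small sketch that uses a plain Int as the environment, a stand-in chosen only for illustration. The names ask, doubled, and sum are hypothetical; ask is the conventional name for a reader that simply returns its environment.

```scala
// Wrapped in an object so these names don't clash with the rest of the file.
object ReaderSemanticsSketch:

  // Same shape as the Reader above.
  case class Reader[-E, +A](run: E => A):
    def map[E2 <: E, B](f: A => B): Reader[E2, B] =
      Reader(e => f(run(e)))
    def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
      Reader(e => f(run(e)).run(e))

  // A reader that just returns its environment.
  val ask: Reader[Int, Int] = Reader((e: Int) => e)

  // map transforms the result; the environment is passed through.
  val doubled: Reader[Int, Int] = ask.map(_ * 2)

  // flatMap feeds the first result into the next reader, and both
  // readers see the same environment value.
  val sum: Reader[Int, Int] =
    doubled.flatMap(d => Reader((e: Int) => d + e))
```

Running sum with the environment 10 gives 30: doubled reads the 10 and doubles it to 20, and the inner reader reads the same 10 again and adds it.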

Example

case class Folk(id: Int, name: String)

We define a bunch of database functions, each of which needs a database connection value.

trait HasConnection:
  def c: java.sql.Connection

def initDb(e: HasConnection): Unit =
  val s1 =
    e.c.prepareStatement(
      """|CREATE TABLE IF NOT EXISTS FOLKS (
         |  ID INT NOT NULL,
         |  NAME VARCHAR(1024),
         |  PRIMARY KEY (ID)
         |)""".stripMargin
    )
  s1.execute()
  s1.close()

def addFolk(f: Folk)(e: HasConnection): Unit =
  val s2 =
    e.c.prepareStatement(
      """|INSERT INTO FOLKS (ID, NAME)
         |VALUES (?, ?)""".stripMargin
    )
  s2.setInt(1, f.id)
  s2.setString(2, f.name)
  s2.execute()
  s2.close()

def getFolk(id: Int)(e: HasConnection): Option[Folk] =
  val s3 =
    e.c.prepareStatement(
      "SELECT ID, NAME FROM FOLKS WHERE ID = ?"
    )
  s3.setInt(1, id)
  val rs3 = s3.executeQuery()
  val folk =
    if rs3.next() then
      Some(Folk(rs3.getInt("ID"), rs3.getString("NAME")))
    else None
  s3.close()
  folk

def showFolk(f: Folk): String =
  s"Folk ${f.id}: ${f.name}"

def close(e: HasConnection): Unit =
  e.c.close()

We also define a generic way to print lines of text. In practice, the output could go to stdout, a log file, and so on.

trait HasPrintLine:
  def printLine(x: String): Unit
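As a sketch of the "log file" case, here is a hypothetical implementation of the trait that writes to any java.io.Writer: a FileWriter for a log file, or a StringWriter in tests. WriterPrintLine is an illustrative name, not part of the original code, and the trait is repeated so the sketch stands alone.

```scala
// Wrapped in an object so these names don't clash with the rest of the file.
object PrintLineSketch:

  trait HasPrintLine:
    def printLine(x: String): Unit

  // Writes each line, followed by "\n", to the given Writer and flushes
  // it, so output is not lost if the program stops early.
  final class WriterPrintLine(w: java.io.Writer) extends HasPrintLine:
    def printLine(x: String): Unit =
      w.write(x)
      w.write("\n")
      w.flush()
```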

To put it all together, we need an environment value that provides both a database connection and a way to print lines, and we pass it explicitly to each function that needs it.

def demo(e: HasConnection with HasPrintLine): Unit =
  initDb(e)
  addFolk(Folk(1, "Folky McFolkface"))(e)
  val fO = getFolk(1)(e)
  val sO = fO.map(showFolk)
  sO.foreach(e.printLine)
  close(e)

def e(db: String): HasConnection with HasPrintLine =
  new HasConnection with HasPrintLine:

    Class.forName("org.hsqldb.jdbcDriver")

    override val c: java.sql.Connection =
      java.sql.DriverManager
        .getConnection(s"jdbc:hsqldb:mem:${db}", "sa", "")

    override def printLine(x: String): Unit = println(x)

demo(e("demo"))

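We can wrap the imperative database functions in readers, which hides the database connection value. Because each function takes its dependency last, partially applying its other arguments gives exactly the function a reader needs. We wrap printLine the same way.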

val initDbR: Reader[HasConnection, Unit] =
  Reader(initDb(_))

def addFolkR(f: Folk): Reader[HasConnection, Unit] =
  Reader(addFolk(f))

def getFolkR(id: Int): Reader[HasConnection, Option[Folk]] =
  Reader(getFolk(id))

def showFolkR(id: Int): Reader[HasConnection, Option[String]] =
  for
    fO <- getFolkR(id)
    sO = for
      f <- fO
    yield showFolk(f)
  yield sO

val closeR: Reader[HasConnection, Unit] =
  Reader(close)

def printLineR(x: String): Reader[HasPrintLine, Unit] =
  Reader(_.printLine(x))
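The conversion relies on partial application: getFolk(id) still lacks its HasConnection argument, so Scala 3 eta-expands it into a function that Reader can wrap. Here is a runnable sketch of the same pattern, assuming a hypothetical in-memory HasStore in place of the JDBC connection. It also shows that showFolkR could equivalently be written with two nested maps, one over the reader and one over the Option inside it.

```scala
// Wrapped in an object so these names don't clash with the rest of the file.
object ConversionSketch:

  case class Reader[-E, +A](run: E => A):
    def map[E2 <: E, B](f: A => B): Reader[E2, B] =
      Reader(e => f(run(e)))
    def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
      Reader(e => f(run(e)).run(e))

  case class Folk(id: Int, name: String)

  // Hypothetical stand-in for HasConnection and the FOLKS table.
  trait HasStore:
    def folks: Map[Int, Folk]

  // Dependency-last function, shaped like the original getFolk.
  def getFolk(id: Int)(e: HasStore): Option[Folk] =
    e.folks.get(id)

  def showFolk(f: Folk): String =
    s"Folk ${f.id}: ${f.name}"

  // getFolk(id) is eta-expanded to HasStore => Option[Folk].
  def getFolkR(id: Int): Reader[HasStore, Option[Folk]] =
    Reader(getFolk(id))

  // Map over the reader, then over the Option inside it.
  def showFolkR(id: Int): Reader[HasStore, Option[String]] =
    getFolkR(id).map(_.map(showFolk))
```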

With the demo converted to readers, we never deal with a database connection value directly; the environment is supplied once, when we run the reader.

val demoR: Reader[HasConnection with HasPrintLine, Unit] =
  for
    _  <- initDbR
    _  <- addFolkR(Folk(1, "Folky McFolkface"))
    fO <- getFolkR(1)
    sO  = fO.map(showFolk)
    _  <- sO match {
            case None => Reader(_ => ())
            case Some(s) => printLineR(s)
          }
    _  <- closeR
  yield ()

demoR.run(e("demoR"))
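Because the environment is supplied only when the reader is run, a program of the same shape can also be run against a test environment. This sketch assumes hypothetical in-memory capabilities in place of the JDBC connection: HasStore backed by a mutable map, and a TestEnv that records printed lines instead of writing to stdout. printFolkR is likewise a hypothetical helper that plays the role of the match on sO in demoR.

```scala
// Wrapped in an object so these names don't clash with the rest of the file.
object DemoSketch:

  case class Reader[-E, +A](run: E => A):
    def map[E2 <: E, B](f: A => B): Reader[E2, B] =
      Reader(e => f(run(e)))
    def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
      Reader(e => f(run(e)).run(e))

  case class Folk(id: Int, name: String)

  // Hypothetical stand-in for HasConnection: a mutable in-memory table.
  trait HasStore:
    def store: scala.collection.mutable.Map[Int, Folk]

  trait HasPrintLine:
    def printLine(x: String): Unit

  def showFolk(f: Folk): String =
    s"Folk ${f.id}: ${f.name}"

  def addFolkR(f: Folk): Reader[HasStore, Unit] =
    Reader((e: HasStore) => e.store.update(f.id, f))

  def getFolkR(id: Int): Reader[HasStore, Option[Folk]] =
    Reader((e: HasStore) => e.store.get(id))

  def printLineR(x: String): Reader[HasPrintLine, Unit] =
    Reader((e: HasPrintLine) => e.printLine(x))

  // Prints the folk if present; otherwise does nothing.
  def printFolkR(fO: Option[Folk]): Reader[HasPrintLine, Unit] =
    fO match
      case Some(f) => printLineR(showFolk(f))
      case None    => Reader((_: HasPrintLine) => ())

  // Same shape as demoR: the type lists every capability it needs.
  val demoR: Reader[HasStore & HasPrintLine, Unit] =
    for
      _  <- addFolkR(Folk(1, "Folky McFolkface"))
      fO <- getFolkR(1)
      _  <- printFolkR(fO)
    yield ()

  // Test environment: stores folks in memory and records printed lines.
  final class TestEnv extends HasStore with HasPrintLine:
    val store = scala.collection.mutable.Map.empty[Int, Folk]
    val lines = scala.collection.mutable.ListBuffer.empty[String]
    def printLine(x: String): Unit = lines += x
```

Nothing happens when demoR is defined; the store is updated and the line recorded only when run is called. For the same reason, a step bound with = instead of <- in a for-comprehension is never run.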

Demo

This file is literate Scala and can be run using Codedown:

$ curl https://earldouglas.com/posts/itof/di-to-reader.md |
  codedown scala |
  scala-cli -q --scala 3.1.1 --dep org.hsqldb:hsqldb:2.3.3 _.sc
Folk 1: Folky McFolkface
Folk 1: Folky McFolkface