Functional Refactoring

January 24, 2015

Refactoring provides an accessible opportunity to learn about imperative and functional design patterns, and Scala's hybrid OO/FP design caters to both.

We explore examples of Scala code written using familiar imperative design patterns, and refactor each one using a counterpart functional pattern. We learn how to replace mutable variables with the state monad, loops with folds, thrown exceptions with sum types, dependency injection with the reader monad, and much more. As we go, we build a mapping of corresponding imperative and functional patterns.

Selections from this page were presented at LambdaConf on May 28, 2016 in Boulder: Video, Slides

Definitions

Functional

adjective

  1. referentially transparent.

Referentially transparent

We use Rúnar Bjarnason's definition of referential transparency

adjective

  1. an expression e is referentially transparent if for all programs p, every occurrence of e in p can be replaced with the result of evaluating e without changing the result of evaluating p.

Refactoring

We use Martin Fowler's definitions of refactoring

noun

  1. a change made to the internal structure of software to make it easier to understand and cheaper to modify without changing its observable behavior.

verb

  1. to restructure software by applying a series of refactorings without changing its observable behavior.

Table of contents

  1. Replace mutator method with deep copy
  2. Replace mutable variables with the state monad
  3. Replace loops with folds
  4. Replace exceptions with sum types
  5. Replace annotation injection with arguments
  6. Replace pass-through arguments with the reader monad
  7. Replace dependency injection with the reader monad
  8. Differentiate values with newtype
  9. Replace code with documentation
  10. Wrap side effects with IO

Drafts

  1. Use the writer monad for logging
  2. Wrap async callbacks with futures
  3. Wrap runtime exceptions in Try
  4. Replace getters and setters with lenses
  5. Convert loops into recursive calls
  6. Replace mutable fields with defensive copy
  7. Replace function with data structure
  8. Replace goto with shift and reset
  9. Wrap nullable values in Option
  10. Replace inheritance with type classes
  11. Parse with validation
  12. Progressions of purity
  13. Parametricity
  14. Progressions of concision
  15. Move dependency injection from run-time to compile-time

Backlog

  1. Replace random access with zipper
  2. Replace accessors and mutators with lenses
  3. Replace malloc and free with a monad
  4. Replace pass-through variables with the reader monad

Replace mutator method with deep copy

Summary

You have a data structure with internal state that can be mutated in place by a method call.

Make the internal state immutable, and convert the mutator into a method that makes a deep copy of the data structure, with the corresponding mutation applied.

Mutable mutability

class MutableEmployee(var name: String, var title: String) {
  def setName(_name: String): Unit = {
    name = _name
  }
  def setTitle(_title: String): Unit = {
    title = _title
  }
}

Immutable mutability

class ImmutableEmployee(name: String, title: String) {
  def setName(_name: String): ImmutableEmployee =
    new ImmutableEmployee(_name, title)
  def setTitle(_title: String): ImmutableEmployee =
    new ImmutableEmployee(name, _title)
}

This is a functional take on the builder pattern.

Motivation

Mutable variables tend to preclude referential transparency, and can lead to all kinds of bugs related to timing, threading, parallelism, lack of idempotency, evaluation order, lack of equational reasoning, etc.

class MEmployee(var name: String, var title: String)

val employee0 = new MEmployee("George Michael", "Employee")
println(s"employee0: ${employee0.name}, ${employee0.title}")
// employee0: George Michael, Employee

employee0.title = "Mr. Manager"
println(s"employee0: ${employee0.name}, ${employee0.title}")
// employee0: George Michael, Mr. Manager

Mechanics

Make internal state final. In Scala, this means simply replacing var with val.

Change mutator methods, which typically return Unit, so that they return the type of the data structure.

Change the implementations of mutator methods to construct new instances of the data structure, propagating the existing state plus the change(s) implied by the mutator.

Example

In Scala, we get this pattern for free when we use case classes:

case class CCEmployee(name: String, title: String)

val employee1 = CCEmployee("George Michael", "Employee")
val employee2 = employee1.copy(title = "Mr. Manager")

println(s"employee1: ${employee1}")
println(s"employee2: ${employee2}")
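
copy only copies one level. When a field is itself a case class, updating a nested value means copying every level on the path to it, while untouched fields are shared with the original, which is safe because they are immutable. A minimal sketch; the `Address` type, its values, and the `NestedCopySketch` wrapper (which keeps these names apart from the code above) are hypothetical:

```scala
// Hypothetical types for illustration; not part of the original example.
object NestedCopySketch {
  final case class Address(street: String, city: String)
  final case class Employee(name: String, title: String, address: Address)

  val employee1 =
    Employee("George Michael", "Employee", Address("1 Main St", "Newport Beach"))

  // Copy the outer Employee and the inner Address along the path to `city`.
  // `name`, `title`, and `street` are shared with employee1, not duplicated.
  val employee2 =
    employee1.copy(address = employee1.address.copy(city = "Boulder"))
}
```

employee1 is unchanged. Nested updates like this grow verbose quickly, which is one motivation for lenses.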

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace mutator method with deep copy' |
  scala-cli --scala 2.13 -
employee0: George Michael, Employee
employee0: George Michael, Mr. Manager
employee1: CCEmployee(George Michael,Employee)
employee2: CCEmployee(George Michael,Mr. Manager)

Replace mutable variables with the state monad

Summary

You have a mutable variable (a var in Scala) that is both read from and written to outside of a tightly-scoped block.

Remodel the block as functions that take an initial value of the variable, and return both the final value of the variable and the original return value of the expression.

var x = 6
println(s"x = ${x}") // x = 6

x = x * 7
println(s"x = ${x}") // x = 42

We use the following form for our state monad:

case class State[A,S](val run: S => (A,S)) {
  def flatMap[B](f: A => State[B,S]): State[B,S] =
    State { s =>
      val (a,s2) = run(s)
      f(a).run(s2)
    }
}

We can use flatMap to implement other helpful methods:

implicit class StateExtras[A,S](s: State[A,S]) {
  def andThen[B](x: State[B,S]): State[B,S] =
    s.flatMap { _ => x }
  def map[B](f: A => B): State[B,S] =
    s.flatMap { a =>
      State { s =>
        (f(a),s)
      }
    }
}
val    six: State[Unit,Int] = State { _ => (                  (),  6) }
val  print: State[Unit,Int] = State { x => (println(s"x = ${x}"),  x) }
val times7: State[Unit,Int] = State { x => (                  (),x*7) }

val sixBy7: State[Unit,Int] = six andThen print andThen times7 andThen print

sixBy7 run -999 // x = 6
                // x = 42

Motivation

Mutable variables tend to preclude referential transparency, and can lead to all kinds of bugs related to timing, threading, parallelism, lack of idempotency, evaluation order, lack of equational reasoning, etc.

The state monad models a mutable state change in a referentially transparent way. With it we can represent data mutation as a functional data structure.
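
As a smaller illustration with no printing, here is a sketch of a counter that hands out sequential IDs, the pure counterpart of reading and then incrementing a `var`. It repeats the State type (adding `map`) inside a hypothetical `CounterSketch` object so it stands alone; `fresh` and `threeIds` are also hypothetical names.

```scala
object CounterSketch {
  // Same shape as the State above, repeated (with its own map) so this block stands alone.
  final case class State[A, S](run: S => (A, S)) {
    def flatMap[B](f: A => State[B, S]): State[B, S] =
      State { s =>
        val (a, s2) = run(s)
        f(a).run(s2)
      }

    def map[B](f: A => B): State[B, S] =
      flatMap(a => State(s => (f(a), s)))
  }

  // Returns the current counter and advances it: the pure counterpart of
  //   val id = counter; counter = counter + 1
  val fresh: State[Int, Int] = State(n => (n, n + 1))

  // Three IDs in a row; each `fresh` sees the state left by the previous one.
  val threeIds: State[List[Int], Int] =
    for {
      a <- fresh
      b <- fresh
      c <- fresh
    } yield List(a, b, c)
}
```

Running `threeIds` from 0 gives `(List(0, 1, 2), 3)`: the IDs plus the final counter. Running it from 0 again gives the same result, because nothing was mutated.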

Mechanics

When encountering a mutable variable, note where it's read from and written to:

var greeting: String = ""

val user = System.getenv("USER")
greeting = s"Greetings, ${user}!"

val osName = System.getProperty("os.name")
greeting = s"${greeting}  ${osName}, eh?  Solid."

println(greeting)

Modularize these reads and writes into discrete functions.

Each function takes an incoming version of the variable, and returns some value along with an outgoing version of the variable.

def getUser(greeting: String): (String,String) = {
  val user = System.getenv("USER")
  (user, s"Greetings, ${user}!")
}

def getOSName(greeting: String): (String,String) = {
  val osName = System.getProperty("os.name")
  (osName, s"${greeting}  ${osName}, eh?  Solid.")
}

def getGreeting(greeting: String): (String,String) =
  (greeting, greeting)
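
It can help to see what the state monad saves us from: threading the state through these functions by hand. In this sketch, `greetUser` and `appendOS` are hypothetical pure stand-ins for getUser and getOSName that return fixed values rather than reading the environment, wrapped in a hypothetical `ManualThreadingSketch` object:

```scala
object ManualThreadingSketch {
  // Hypothetical pure stand-ins: fixed values instead of System.getenv/getProperty.
  def greetUser(greeting: String): (String, String) =
    ("james", "Greetings, james!")

  def appendOS(greeting: String): (String, String) =
    ("Linux", s"${greeting}  Linux, eh?  Solid.")

  // Each call's outgoing state must be passed, by hand, into the next call.
  // Passing s0 instead of s1 to appendOS would compile, but drop the greeting.
  val s0 = ""
  val (user, s1)   = greetUser(s0)
  val (osName, s2) = appendOS(s1)
}
```

Every intermediate state needs its own name, and nothing stops us from passing the wrong one; State's flatMap does this plumbing for us.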

To make these explicitly composable, wrap each using the state monad:

val getUserS: State[String,String] =
  State { getUser }

val getOSNameS: State[String,String] =
  State { getOSName }

val getGreetingS: State[String,String] =
  State { getGreeting }

getUserS andThen getOSNameS andThen getGreetingS map println run ""

Example: stateful references

When it's preferable to code in terms of mutable references, we can use a data structure that wraps a mutable variable and represents its access and mutation using the state monad.

This data structure is frequently called STRef, or StateRef:

class StateRef[A,S](_a: A) {

  private var a: A = _a

  def read: State[A,S] =
    State(s => (a,s))

  def write(a2: A): State[Unit,S] =
    State({s =>
      a = a2
      ((),s)
    })

}

object StateRef {
  def apply[S,A](a: A): State[StateRef[A,S],S] =
    State(s => (new StateRef(a),s))
}

We can use one or more StateRefs together to code in terms of mutable references, while keeping the desired characteristics of the state monad:

val nameS: State[String,Unit] =
  for {
    firstSR  <- StateRef("James")
    middleSR <- StateRef("Earl")
    lastSR   <- StateRef("Douglas")
    nameSR   <- StateRef("")
    first    <- firstSR.read
    middle   <- middleSR.read
    _        <- middleSR.write("Buster") // does not overwrite `middle`
    last     <- lastSR.read
    _        <- nameSR.write(s"${first} ${middle} ${last}")
    name     <- nameSR.read
  } yield name

nameS map println run ()

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace mutable variables with the state monad' |
  scala-cli --scala 2.13 -
x = 6
x = 42
x = 6
x = 42
Greetings, james!  Linux, eh?  Solid.
Greetings, james!  Linux, eh?  Solid.
James Earl Douglas

Replace loops with folds

Summary

You have a for loop, while loop, or a do while loop that mutates an initial value by some algorithm.

Replace the loop with a fold, passing the initial value and algorithm as arguments.

Motivation

Loops generally coincide with mutable state, which is not referentially transparent.

Folds not only eliminate mutation, but they're more concise. Less code means fewer opportunities for typos and logic errors.

Mechanics

When encountering a loop over a foldable data structure, note what state it's computing with each iteration, and what the initial state is:

val xs = List(1,2,3,4)
var i = 0
var sum = 0
while (i < xs.length) {
  val x = xs(i)
  sum = sum + x
  i = i + 1
}
println(s"sum: ${sum}") // sum: 10

Here, our loop accumulates a sum of the list values. The initial state of the sum is zero.

Rewrite the loop body as a function that takes the prior value of the computed state and the current element of the foldable data structure, and returns the next value of the computed state.

val sumInit = 0
val sumOp: (Int,Int) => Int = { (sum, x) => sum + x }

Fold over the original data structure, passing the initial state value and the rewritten loop body function.

Here, we use a left fold:

foldLeft: List[A] => B => (((B, A) => B) => B)
val foldSum = xs.foldLeft(sumInit)(sumOp)
println(s"fold sum: ${foldSum}") // fold sum: 10
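
A loop that updates several variables per iteration folds the same way; the accumulator just becomes a tuple. A minimal sketch computing a sum and a count in one pass, and a mean from them (`MeanSketch` is a hypothetical wrapper):

```scala
object MeanSketch {
  val xs = List(1, 2, 3, 4)

  // A loop would update two vars per iteration:
  //   sum = sum + x; count = count + 1
  // The fold carries both in one tuple accumulator instead.
  val (sum, count) =
    xs.foldLeft((0, 0)) { case ((s, c), x) => (s + x, c + 1) }

  val mean: Double = sum.toDouble / count
}
```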

Example: looping without a foldable data structure

Not all loops are over a foldable data structure like List. Sometimes we just loop until some arbitrary condition is false:

var count = 0
while (count < 10) {
  count = count + 1
}
println(s"count: ${count}") // count: 10

Here, the loop just counts its own iterations, stopping after ten.

In this case, we just need to fold over a data structure of the right length. We don't care about its values.

In this example, we can use a Range:

val countInit = 0
val countOp: (Int,Int) => Int = { (count, _) => count + 1 }
val foldCount = (0 until 10).foldLeft(countInit)(countOp)
println(s"fold count: ${foldCount}") // fold count: 10
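
Counting to ten has a known bound, so a Range of the right length is easy to find. When the number of iterations depends on the loop's own state, there is no such structure up front. One option, sketched here, is an unfold rather than a fold: `Iterator.iterate` lazily produces the successive loop states, and `dropWhile` skips them while the loop condition still holds. The doubling loop and the `DoublingSketch` wrapper are hypothetical:

```scala
object DoublingSketch {
  // Imperative version, for comparison:
  //   var x = 1
  //   var steps = 0
  //   while (x < 1000) { x = x * 2; steps = steps + 1 }

  // Each element is one loop state: (current value, iterations so far).
  // dropWhile skips the states where the loop would keep going, and
  // next() returns the first state where the condition is false.
  val (finalX, steps) =
    Iterator
      .iterate((1, 0)) { case (x, n) => (x * 2, n + 1) }
      .dropWhile { case (x, _) => x < 1000 }
      .next()
}
```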

Example: square each number in a list

Not all loops reduce a bunch of values into a single result. Sometimes the result is another bunch of values.

var j = xs.length - 1
var sq = List[Int]()
while (j >= 0) {
  val x = xs(j)
  sq = (x * x) :: sq
  j = j - 1
}
println(s"squares: ${sq}") // squares: List(1, 4, 9, 16)

Here, we create a new list containing the squares of the numbers in the old list.

This time, we use a right fold:

foldRight: List[A] => B => (((A, B) => B) => B)
val foldSq = xs.foldRight(List[Int]())((x, z) => (x * x) :: z)
println(s"fold squares: ${foldSq}") // fold squares: List(1, 4, 9, 16)

If we had used a left fold, our result would be backwards:

val foldSqRev = xs.foldLeft(List[Int]())((z, x) => (x * x) :: z)
println(s"fold squares reverse: ${foldSqRev}") // fold squares reverse: List(16, 9, 4, 1)
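
A right fold that conses each transformed element onto the accumulator is exactly what `map` does, so when a loop produces one output per input, `map` says so most directly. A quick comparison (`MapSketch` is a hypothetical wrapper):

```scala
object MapSketch {
  val xs = List(1, 2, 3, 4)

  // Right fold that conses each squared element onto the accumulator.
  val viaFold: List[Int] = xs.foldRight(List.empty[Int])((x, acc) => (x * x) :: acc)

  // The same traversal, named for what it does.
  val viaMap: List[Int] = xs.map(x => x * x)
}
```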

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace loops with folds' |
  scala-cli --scala 2.13 -
sum: 10
fold sum: 10
count: 10
fold count: 10
squares: List(1, 4, 9, 16)
fold squares: List(1, 4, 9, 16)
fold squares reverse: List(16, 9, 4, 1)

Replace exceptions with sum types

Summary

You have a function that either returns a value or throws an exception.

Change the function's return type to a disjunction of the original return type and the exception.

We use the following form for our disjunction, called Either:

sealed trait Either[A,B]
case class Left[A,B](x: A) extends Either[A,B]
case class Right[A,B](x: B) extends Either[A,B]

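This is similar to Scala's built-in Either type. We define our own here to show how for-comprehension support works; the built-in Either has been right-biased, with map and flatMap, since Scala 2.12.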

Motivation

Checked exceptions force the developer to wrap their code in try/catch blocks, yielding noisy code that's hard to reason about, hard to combine, and impure in languages where try/catch blocks are not expressions.

Unchecked exceptions allow the developer to forget to wrap their code in try/catch blocks, creating a runtime timebomb.

Mechanics

When encountering an expression that throws an exception, note the type of exception it can throw and the type of value it can return. We'll call these A and B, respectively.

Remodel the expression's return type to Either[A,B].

Change any throw e into Left(e).

Change any return b into Right(b).
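
To see the mechanics on a function that throws explicitly, here is a sketch with a hypothetical `parseAge` (and negative-age rule) inside a hypothetical `ParseAgeSketch` object. It uses the standard library's `scala.util.Either`, written out in full so it cannot collide with the custom `Either` above:

```scala
object ParseAgeSketch {
  // Before: failure is signalled by throwing.
  def parseAge(s: String): Int = {
    val n = s.trim.toInt // throws NumberFormatException on bad input
    if (n < 0) throw new IllegalArgumentException(s"negative age: $n")
    n
  }

  // After: `throw e` becomes Left(e) and the normal result becomes Right(n).
  // The library call toInt can still throw, so it is caught at this boundary.
  def parseAgeE(s: String): scala.util.Either[Exception, Int] =
    try {
      val n = s.trim.toInt
      if (n < 0) scala.util.Left(new IllegalArgumentException(s"negative age: $n"))
      else scala.util.Right(n)
    } catch {
      case e: NumberFormatException => scala.util.Left(e)
    }
}
```

`parseAgeE("42")` is `Right(42)`, while `parseAgeE("-1")` and `parseAgeE("abc")` are both `Left` values.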

Example

Exceptions

def div(x: Int, y: Int): Int = x / y
val quotient: Int =
  try {
    div(42, 0)
  } catch {
    case e: ArithmeticException => -999
  }

println(s"quotient: ${quotient}") // quotient: -999

Either

def divE(x: Int, y: Int): Either[Exception, Int] =
  try {
    Right(x / y)
  } catch {
    case e: Exception => Left(e)
  }
val quotientE: Either[Exception,Int] = divE(42, 0)
println(s"quotientE: ${quotientE}") // quotientE: Left(java.lang.ArithmeticException: / by zero)

Right-biased either

implicit class RightBiasedEither[A,B](e: Either[A,B]) {

  def flatMap[C](f: B => Either[A,C]): Either[A,C] =
    e match {
      case Left(a) => Left(a)
      case Right(b) => f(b)
    }

  def map[C](f: B => C): Either[A,C] =
    flatMap { b =>
      Right(f(b))
    }

}
val quotientE2: Either[Exception,Int] =
  for {
    a <- divE(42, 7)
    b <- divE(a, 0)
  } yield b

println(s"quotientE2: ${quotientE2}") // quotientE2: Left(java.lang.ArithmeticException: / by zero)
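
For comparison, the same computation can use the standard library's `scala.util.Either` directly, with no implicit class, and `getOrElse` can recover the original `-999` default at the edge of the program. A sketch, wrapped in a hypothetical `StdEitherSketch` object:

```scala
object StdEitherSketch {
  def divE(x: Int, y: Int): scala.util.Either[ArithmeticException, Int] =
    try scala.util.Right(x / y)
    catch { case e: ArithmeticException => scala.util.Left(e) }

  // No implicit class needed: scala.util.Either defines map and flatMap itself.
  val quotientE2: scala.util.Either[ArithmeticException, Int] =
    for {
      a <- divE(42, 7)
      b <- divE(a, 0)
    } yield b

  // At the edge of the program, fall back to a default value.
  val quotient: Int = quotientE2.getOrElse(-999)
}
```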

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace exceptions with sum types' |
  scala-cli --scala 2.13 -
quotient: -999
quotientE: Left(java.lang.ArithmeticException: / by zero)
quotientE2: Left(java.lang.ArithmeticException: / by zero)

Replace annotation injection with arguments

Summary

You have fields or parameters that are resolved at run-time via an @Inject annotation.

Eliminate the annotation, and directly set fields and pass parameters.

Motivation

Resolving dependencies at run-time delays the detection of missing or problematic values, increases the cost of fixing them, and broadens the impact of having them. Moving dependency resolution to compile-time saves time, reduces cost, and improves user experience.

Mechanics

Remove uses of @Inject.

Remove run-time dependency resolution configuration.

Add calls at the edge of the system to wire up the application directly.

Example

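import com.google.inject.AbstractModule;
import com.google.inject.BindingAnnotation;
import com.google.inject.Guice;
import com.google.inject.Inject;
import com.google.inject.Injector;
import com.google.inject.Provides;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

public class Before {

  // Guice only treats @Greeting as a qualifier if it is a binding
  // annotation that is retained at run time.
  @BindingAnnotation
  @Retention(RetentionPolicy.RUNTIME)
  @interface Greeting {}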

  static class EnModule extends AbstractModule {
    @Provides
    @Greeting
    static String getGreeting() {
      return "Hello, world!";
    }
  }

  static class Greeter {
    private final String message;

    @Inject
    Greeter(@Greeting final String message) {
      this.message = message;
    }
  }

  static void run(final Greeter g) {
    System.out.println(g.message);
  }

  public static void main(String[] args) {
    final Injector injector = Guice.createInjector(new EnModule());
    final Greeter greeter = injector.getInstance(Greeter.class);
    run(greeter);
  }
}
public class After {

  @interface Greeting {}

  static class EnModule {
    final String getGreeting() {
      return "Hello, world!";
    }
  }

  static class Greeter {
    private final String message;

    Greeter(final String message) {
      this.message = message;
    }
  }

  static void run(final Greeter g) {
    System.out.println(g.message);
  }

  public static void main(String[] args) {
    final Greeter greeter = new Greeter(new EnModule().getGreeting());
    run(greeter);
  }
}

Try it out

This section is literate Java, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown java --section '## Replace annotation injection with arguments' |
  scala-cli --dep com.google.inject:guice:7.0.0 -

Replace pass-through arguments with the reader monad

Summary

You have multiple functions that reuse an argument, either directly or to pass along to another function call.

Remove the argument from functions that don't need it, and remodel functions that do need it as partially-curried on that argument.

Motivation

When dependencies are injected at the top level, they have to be passed around from one function to another, regardless of whether or not any one particular function actually uses them.

The reader monad lets us encode dependencies directly into the type, while providing composability.

Mechanics

We use the following form for our reader monad:

case class Reader[-E, +A](run: E => A):

  def map[E2 <: E, B](f: A => B): Reader[E2, B] =
    Reader(e => f(run(e)))

  def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
    Reader(e => f(run(e)).run(e))

When encountering functions that pass an argument around to each other, extract the common dependency, convert the blocks that need it to readers of that dependency, and put them back together using map and flatMap.

Example

val PI = 3.14

def area(pi: Double, r: Int): Double = pi * r * r
def volume(pi: Double, r: Int, h: Int): Double =
  val a = area(pi, r)
  val v = a * h
  v
val vol = volume(PI, 6, 7)
println(s"vol: ${vol}") // vol: 791.28

With a reader, we can extract Pi as a dependency:

def areaR(r: Int): Reader[Double, Double] =
  Reader(pi => pi * r * r)

def volumeR(r: Int, h: Int): Reader[Double, Double] =
  areaR(r).map(a => a * h)
val r = 6
val h = 7

val volR = volumeR(r, h) run PI
println(s"volR: ${volR}") // volR: 791.28

The volumeR reader still has an unnecessary dependency on the radius argument, which we can extract as well:

def volumeRR(h: Int): Reader[Int, Reader[Double, Double]] =
  Reader(r => areaR(r).map(a => a * h))
val volRR = volumeRR(h) run r run PI
println(s"volRR: ${volRR}") // volRR: 791.28
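
The example above only needs `map`. When two steps each need the dependency, `flatMap` (here via a for-comprehension) combines them, and both receive the same environment when the result is run. A minimal sketch computing a closed cylinder's surface area, 2πr² + 2πrh; `circumference`, `surfaceArea`, and the `ReaderSketch` wrapper are hypothetical, and Reader is restated with braces so the block compiles on Scala 2.13 or 3:

```scala
object ReaderSketch {
  // Same Reader as above, written with braces so it compiles on Scala 2.13 or 3.
  final case class Reader[-E, +A](run: E => A) {
    def map[E2 <: E, B](f: A => B): Reader[E2, B] =
      Reader(e => f(run(e)))

    def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
      Reader(e => f(run(e)).run(e))
  }

  def area(r: Int): Reader[Double, Double] =
    Reader(pi => pi * r * r)

  def circumference(r: Int): Reader[Double, Double] =
    Reader(pi => 2 * pi * r)

  // Both readers receive the same pi; neither is handed it explicitly.
  def surfaceArea(r: Int, h: Int): Reader[Double, Double] =
    for {
      a <- area(r)
      c <- circumference(r)
    } yield 2 * a + c * h
}
```

With π taken as 3.0, `surfaceArea(1, 1).run(3.0)` is `12.0`: 2 × 3 + 6 × 1.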

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace pass-through arguments with the reader monad' |
  scala-cli --scala 3.3 -
vol: 791.28
volR: 791.28
volRR: 791.28

Replace dependency injection with the reader monad

Summary

You have functions with different dependencies, and want to compose them.

Remove dependencies from function arguments, and remodel the functions as partially-curried on those dependencies.

Motivation

The reader monad lets us encode dependencies directly into the type, while providing composability.

Mechanics

A reader encapsulates a function, making it composable with other readers.

case class Reader[-E, +A](run: E => A):

  def map[E2 <: E, B](f: A => B): Reader[E2, B] =
    Reader(e => f(run(e)))

  def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
    Reader(e => f(run(e)).run(e))

Example

case class Folk(id: Int, name: String)

We define a bunch of database functions, each of which needs a database connection value.

trait HasConnection:
  def c: java.sql.Connection

def initDb(e: HasConnection): Unit =
  val s1 =
    e.c.prepareStatement(
      """|CREATE TABLE IF NOT EXISTS FOLKS (
         |  ID INT NOT NULL,
         |  NAME VARCHAR(1024),
         |  PRIMARY KEY (ID)
         |)""".stripMargin
    )
  s1.execute()
  s1.close()

def addFolk(f: Folk)(e: HasConnection): Unit =
  val s2 =
    e.c.prepareStatement(
      """|INSERT INTO FOLKS (ID, NAME)
         |VALUES (?, ?)""".stripMargin
    )
  s2.setInt(1, f.id)
  s2.setString(2, f.name)
  s2.execute()
  s2.close()

def getFolk(id: Int)(e: HasConnection): Option[Folk] =
  val s3 =
    e.c.prepareStatement(
      "SELECT ID, NAME FROM FOLKS WHERE ID = ?"
    )
  s3.setInt(1, id)
  val rs3 = s3.executeQuery()
  val folk =
    if (rs3.next()) then
      Some(Folk(rs3.getInt("ID"), rs3.getString("NAME")))
    else None
  s3.close()
  folk

def showFolk(f: Folk): String =
  s"Folk ${f.id}: ${f.name}"

def close(e: HasConnection): Unit =
  e.c.close()

We also define a generic way to print lines of text. In practice this could go to stdout, or a log file, etc.

trait HasPrintLine:
  def printLine(x: String): Unit

To put it together, we need a database connection value which we pass around to each database function.

def demo(e: HasConnection with HasPrintLine): Unit =
  initDb(e)
  addFolk(Folk(1, "Folky McFolkface"))(e)
  val fO = getFolk(1)(e)
  val sO = fO.map(showFolk)
  sO.foreach(e.printLine)
  close(e)

def e(db: String): HasConnection with HasPrintLine =
  new HasConnection with HasPrintLine:

    Class.forName("org.hsqldb.jdbcDriver")

    override val c: java.sql.Connection =
      java.sql.DriverManager
        .getConnection(s"jdbc:hsqldb:mem:${db}", "sa", "")

    override def printLine(x: String): Unit = println(x)

demo(e("demo"))

We can convert the imperative database functions to readers to hide the database connection values.

val initDbR: Reader[HasConnection, Unit] =
  Reader(initDb(_))

def addFolkR(f: Folk): Reader[HasConnection, Unit] =
  Reader(addFolk(f))

def getFolkR(id: Int): Reader[HasConnection, Option[Folk]] =
  Reader(getFolk(id))

def showFolkR(id: Int): Reader[HasConnection, Option[String]] =
  for
    fO <- getFolkR(id)
    sO = for
      f <- fO
    yield showFolk(f)
  yield sO

val closeR: Reader[HasConnection, Unit] =
  Reader(close)
def printLineR(x: String): Reader[HasPrintLine, Unit] =
  Reader(_.printLine(x))

Converting our demo from before, we never have to directly deal with a database connection value.

val demoR: Reader[HasConnection with HasPrintLine, Unit] =
  for
    _  <- initDbR
    _  <- addFolkR(Folk(1, "Folky McFolkface"))
    fO <- getFolkR(1)
    sO  = fO.map(showFolk)
    _  <- sO match {
            case None => Reader(_ => ())
            case Some(s) => printLineR(s)
          }
    _  <- closeR
  yield ()

demoR.run(e("demoR"))
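
Because a reader only describes the work, the same program can run against a different environment, such as a test double that records output instead of printing it. A self-contained sketch; `HasUserName`, `greetR`, `TestEnv`, and the `ReaderTestSketch` wrapper are hypothetical, and Reader and HasPrintLine are restated (with braces) so the block stands alone:

```scala
object ReaderTestSketch {
  // Same Reader as above, written with braces so it compiles on Scala 2.13 or 3.
  final case class Reader[-E, +A](run: E => A) {
    def map[E2 <: E, B](f: A => B): Reader[E2, B] =
      Reader(e => f(run(e)))

    def flatMap[E2 <: E, B](f: A => Reader[E2, B]): Reader[E2, B] =
      Reader(e => f(run(e)).run(e))
  }

  trait HasPrintLine { def printLine(x: String): Unit }
  trait HasUserName { def userName: String } // hypothetical dependency

  def printLineR(x: String): Reader[HasPrintLine, Unit] = Reader(_.printLine(x))
  val userNameR: Reader[HasUserName, String] = Reader(_.userName)

  // The required environment combines what each step needs.
  val greetR: Reader[HasUserName with HasPrintLine, Unit] =
    for {
      name <- userNameR
      _    <- printLineR(s"Hello, ${name}!")
    } yield ()

  // A test double: records lines instead of printing them.
  final class TestEnv(val userName: String) extends HasUserName with HasPrintLine {
    val lines = scala.collection.mutable.ListBuffer.empty[String]
    def printLine(x: String): Unit = { lines += x; () }
  }
}
```

The required environment, `HasUserName with HasPrintLine`, is the combination of what each step needs, which the `E2 <: E` bounds on map and flatMap let the compiler check. In a test, we run `greetR` with a `TestEnv` and inspect `lines`.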

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace dependency injection with the reader monad' |
  scala-cli --scala 3.3 --dep org.hsqldb:hsqldb:2.7.2 -
Folk 1: Folky McFolkface
Folk 1: Folky McFolkface

Differentiate values with newtype

Summary

You have multiple references, often function arguments, that share a common data type.

Make the types of each reference distinct using newtype.

Strong type, weak meaning

One of my favorite advantages of static type systems is how safe they make the process of refactoring. When I change a function's type signature, the compiler lets me know, before I ever try to run my code, everywhere I need to update my program to work with the change.

Consider the following refactoring: I want setName, which currently takes arguments for first and last names, to take just a single argument for a full name.

def setName(first: String, last: String): Unit

                    ||
                    ||
                    \/

def setName(fullName: String): Unit

This change will ripple compile errors through my program, showing me each place I need to update my calls to it. Unfortunately, refactoring code does not always imply a type signature change, and API errors can sneak into my code under the compiler's radar.

Consider the following refactoring: I want to swap the order of two arguments of the same type.

def search(haystack: String, needle: String): Unit

                    ||
                    ||
                    \/

def search(needle: String, haystack: String): Unit

This breaks any uses I have of the search function, but the compiler isn't able to let me know about it because it can't distinguish needle from haystack; they're both Strings. The search function is "stringly typed".

One way around this is to make discrete record types for needle and haystack, so that this refactoring causes a change in the type signature of search. Unfortunately, this single-value "boxing" of data leads to an ever-growing pile of new data types, each of which imposes additional CPU and memory overhead.
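
For reference, the record-type workaround looks like this: swapping the arguments becomes a compile error, at the cost of a wrapper object around every value. `BoxedSearch` is a hypothetical name:

```scala
object BoxedSearch {
  // One wrapper type per role; each wrapper is a separate object at run time.
  final case class Needle(value: String)
  final case class Haystack(value: String)

  // search(Haystack(...), Needle(...)) would now fail to compile.
  def search(needle: Needle, haystack: Haystack): Boolean =
    haystack.value.contains(needle.value)
}
```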

Newtype

What I want is a way to distinguish between these values at compile time, but avoid extra computational burden at runtime. It turns out that several languages support this idea, commonly known as "newtype".

Examples

Newtype in Go via type

Go supports this via type definitions: the type keyword creates a new, distinct type with the same underlying type.

search.go:

package main

import "fmt"
import "strings"

type Needle string
type Haystack string

func search(n Needle, h Haystack) bool {
  return strings.Contains(string(h), string(n))
}

func main() {

  const n Needle = "needle"
  const h Haystack = "This haystack is nothing but needles!"

  fmt.Println(search(n, h))
}
$ go run search.go
true

Newtype in Haskell via newtype

Haskell supports newtype via the newtype keyword.

search.hs:

import Data.List (isInfixOf)

newtype Needle = Needle String
newtype Haystack = Haystack String

-- isInfixOf tests for a contiguous substring, like strings.Contains in Go.
search :: Needle -> Haystack -> Bool
search (Needle n) (Haystack h) = n `isInfixOf` h

main :: IO ()
main = do
  let n = Needle "needle"
  let h = Haystack "This haystack is nothing but needles!"
  print $ search n h
$ runhaskell search.hs
True

Roll your own newtype in Scala

Scala 2 does not natively support newtype, but a basic approximation can be written in a few lines using phantom types.

Search.scala:

object Search {

  trait Needle
  trait Haystack

  def apply(n: String with Needle, h: String with Haystack): Boolean =
    h contains n
}

implicit class Tagged[A](val a: A) extends AnyVal {
  def tag[B]: A with B = a.asInstanceOf[A with B]
}

object Main extends App {

  import Search.Needle
  import Search.Haystack

  val n: String with Needle = "needle".tag[Needle]
  val h = "This haystack is nothing but needles!".tag[Haystack]

  println(Search(n, h))
}
$ scala Search.scala
true

Support for newtype in Scala is also available in libraries like Scalaz and Shapeless as tagged types.

Newtype in Scala via the newtype library

NewType is a library that adds a @newtype macro to express newtype:

@newtype case class WidgetId(toInt: Int)

Newtype in Scala 3 via opaque types

Scala 3 supports newtype natively via opaque types.
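
A minimal Scala 3 sketch of the search example with opaque types (`SearchOpaque` is a hypothetical name; this does not compile on Scala 2). Inside the defining object, Needle and Haystack are just String; outside it they are distinct types, so swapping the arguments to `search` fails to compile, and at run time the values remain plain Strings with no wrapper:

```scala
// Scala 3 only: opaque type aliases are not available in Scala 2.
object SearchOpaque {
  opaque type Needle = String
  object Needle {
    def apply(s: String): Needle = s
  }

  opaque type Haystack = String
  object Haystack {
    def apply(s: String): Haystack = s
  }

  // Within SearchOpaque both types are known to be String,
  // so String methods such as contains are available here.
  def search(n: Needle, h: Haystack): Boolean = h.contains(n)
}
```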

Tagged types in Scalaz

Scalaz adds support for unboxed tagged types via the scalaz.Tag module.

build.sbt:

libraryDependencies += "org.scalaz" %% "scalaz-core" % "7.3.8"

search.scala:

import scalaz._
import Tag._

trait Needle
trait Haystack

object Search {
  def apply(n: String @@ Needle, h: String @@ Haystack): Boolean =
    unwrap(h) contains unwrap(n)
}

val n = Tag[String,Needle]("needle")
val h = Tag[String,Haystack]("This haystack is nothing but needles!")

println(Search(n, h))
$ curl https://earldouglas.com/itof.md |
  codedown scala --section '#### Tagged types in Scalaz' |
  scala-cli --scala 2.12 --dep org.scalaz::scalaz-core:7.3.8 -
true

Tagged types in Shapeless

Shapeless adds support for type tagging via the shapeless.tag module.

build.sbt:

libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.10"

search.scala:

import shapeless.tag
import shapeless.tag.@@

trait Needle
trait Haystack

object Search {
  def apply(n: String @@ Needle, h: String @@ Haystack): Boolean =
    h contains n
}

val n = tag[Needle][String]("needle")
val h = tag[Haystack][String]("This haystack is nothing but needles!")

println(Search(n, h))
$ curl https://earldouglas.com/itof.md |
  codedown scala --section '#### Tagged types in Shapeless' |
  scala-cli --scala 2.13 --dep com.chuusai::shapeless:2.3.10 -

Newtype in Java

Java does not natively support newtype, but we can adapt our Scala approach to use a boxed tagger. Unlike the Scala version, this allocates a wrapper object for each tagged value.

Search.java:

class Tagged<A, B> {
  public final A value;
  public Tagged(final A value) {
    this.value = value;
  }
  public static <A, B> Tagged<A, B> tag(final A value) {
    return new Tagged<>(value);
  }
}

public class Search {

  interface Needle { }
  interface Haystack { }

  static boolean search(
    final Tagged<String, Needle> needle,
    final Tagged<String, Haystack> haystack) {
    return haystack.value.indexOf(needle.value) != -1;
  }

  public static void main(String[] args) {

    final Tagged<String, Needle> n =
      Tagged.tag("needle");

    final Tagged<String, Haystack> h =
      Tagged.tag("This haystack is nothing but needles!");

    System.out.println(search(n, h));
  }
}
$ scala-cli Search.java
true

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '#### Newtype in Scala' |
  scala-cli --scala 2.13 _
true

Replace code with documentation

Summary

You have source code with various in-line comments, comment blocks, and accompanying forms of documentation.

Make the documentation first class alongside the source code by restructuring them as a unified product.

Literate programming

Literate programming is not strictly a functional pattern. In practice it has been popularized by Haskell, so we'll let it slide and call it functionally-inspired.

Self-coding documentation

Consider an imaginary article about approaches to approximating Pi using recursive summation in Scala:


## Approximating Pi through recursive summation

### Abstract

This is an imaginary article about approaches to estimating Pi using
recursive summation in Scala.

### Background

Pi can be represented as the sum of an infinite series:

```
              k
       ∞  (-1)
π = 4  ∑  ------
      k=0 2k + 1
```

### Initial implementation

Let's write a recursive function that can be used to estimate this sum.

First, we need a way to compute each kth term:

```scala
def term(k: Int): Double = {
  val n = 4 * math.pow(-1, k)
  val d = 2 * k + 1
  n / d.toDouble
}
```

Next, we need a way to sum them:

```scala
def pi(k: Int): Double =
  if (k < 0) 0
  else term(k) + pi(k - 1)
```

Let's try it out:

```scala
println(pi(1))    // 2.666666666666667
println(pi(10))   // 3.232315809405594
println(pi(100))  // 3.1514934010709914
println(pi(1000)) // 3.1425916543395442
```

That's close, but not close enough.  Unfortunately, this implementation
can't do much better:

```
println(pi(10000)) // Throws java.lang.StackOverflowError
```

Attempting to run merely ten thousand iterations overflows the stack!

### Improved implementation

To avoid blowing the stack, we need a tail-recursive version of this
function:

```scala
@scala.annotation.tailrec
def piTR(k: Int, acc: Double = 0): Double =
  if (k < 0) acc
  else piTR(k - 1, term(k) + acc)
```

Let's put it to the test:

```scala
println(piTR(100000000)) // 3.141592663589793
```

Computing one hundred million terms may take a while, but since
`piTR` is tail-recursive, it does eventually finish.

Using literate programming tools such as Codedown and sbt-lit, we can treat this not as an article but as our source code, and keep code and documentation inextricably linked.

Try it out

This section is literate Scala, and can be run using Codedown.

For presentation purposes, the source uses Scala code blocks wrapped in Markdown code blocks. To get to the Scala, we first need to extract the Markdown with an extra codedown markdown invocation:

$ curl https://earldouglas.com/itof.md |
  codedown markdown --section '## Replace code with documentation' |
  codedown scala |
  scala-cli --scala 2.13 -
2.666666666666667
3.232315809405594
3.1514934010709914
3.1425916543395442
3.141592663589793

Wrap side effects with IO

Summary

You have imperative code that breaks referential transparency.

Wrap the imperative code with IO, which can be referenced safely without triggering its side effects, and save the side effects for the last, outermost layer of your program.

Imperative

def multiplyM(x: Int, y: Int): Int = {
  val z: Int = x * y
  println(s"${x} * ${y} = ${z}")
  z
}
val zM: Int = {
  multiplyM(6, 7) // prints "6 * 7 = 42"
  multiplyM(6, 7) // prints "6 * 7 = 42"
  multiplyM(6, 7) // prints "6 * 7 = 42"
} // zM = 42

IO

class IO[A](a: => A) {

  def unsafePerformIO(): A = a

  def map[B](f: A => B): IO[B] =
    IO(f(a))

  def flatMap[B](f: A => IO[B]): IO[B] =
    IO(f(a).unsafePerformIO())
}

object IO {
  def apply[A](a: => A): IO[A] =
    new IO(a)
}

Pure

def multiplyI(x: Int, y: Int): IO[Int] =
  IO[Int] {
    val z: Int = x * y
    println(s"${x} * ${y} = ${z}")
    z
  }
val zI: IO[Int] = {
  multiplyI(6, 7) // does not print anything
  multiplyI(6, 7) // does not print anything
  multiplyI(6, 7) // does not print anything
} // zI = IO[=> 42]

zI.unsafePerformIO() // prints "6 * 7 = 42"
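To make the difference between describing and running an effect observable, here is a minimal, self-contained sketch. It re-declares the IO type above under the hypothetical name Task (so it stands alone), and appends to a buffer instead of printing, so the effects can be inspected. Building the program with a for-comprehension records nothing; each call to run() performs the effects again.

```scala
import scala.collection.mutable.ListBuffer

object IOSketch {

  // Same shape as IO above, renamed to Task so this sketch stands alone
  final class Task[A](thunk: => A) {
    def run(): A = thunk
    def map[B](f: A => B): Task[B] = Task(f(run()))
    def flatMap[B](f: A => Task[B]): Task[B] = Task(f(run()).run())
  }

  object Task {
    def apply[A](a: => A): Task[A] = new Task(a)
  }

  // Stand-in for println: each effect appends a message to this buffer
  val events: ListBuffer[String] = ListBuffer.empty[String]

  def record(msg: String): Task[Unit] =
    Task { events += msg; () }

  // Describing the program performs no effects yet
  val program: Task[Int] =
    for {
      _ <- record("start")
      _ <- record("end")
    } yield 42
}
```

Because program is an ordinary value, it can be passed around, reused, or discarded without any effect taking place; only run() at the edge of the program touches the outside world.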

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Wrap side effects with IO' |
  scala-cli --scala 2.13 -
6 * 7 = 42
6 * 7 = 42
6 * 7 = 42
6 * 7 = 42

Use the writer monad for logging

Imperative logging

def log(x: String): Unit =
  println(x)

def add(x: Int, y: Int): Int = {
  log(s"Adding ${x} and ${y}...")
  val z = x + y
  log(s"Result: ${z}")
  z
}

def multiply(x: Int, y: Int): Int = {
  log(s"Multiplying ${x} and ${y}...")
  val z = x * y
  log(s"Result: ${z}")
  z
}
def demo(): Unit = {
  multiply(add(1, 5), 7)
}

demo()

Writer

case class Writer[A](value: A, log: List[String]) {
  def map[B](f: A => B): Writer[B] =
    Writer(f(value), log)
  def flatMap[B](f: A => Writer[B]): Writer[B] = {
    val wb = f(value)
    Writer(wb.value, log ++ wb.log)
  }
}

Functional logging

def logW(x: String): Writer[Unit] =
  Writer((), List(x))

def addW(x: Int, y: Int): Writer[Int] =
  for {
    _ <- logW(s"Adding ${x} and ${y}...")
    z  = x + y
    _ <- logW(s"Result: ${z}")
  } yield z

def multiplyW(x: Int, y: Int): Writer[Int] =
  for {
    _ <- logW(s"Multiplying ${x} and ${y}...")
    z  = x * y
    _ <- logW(s"Result: ${z}")
  } yield z
def demoW(): Writer[Unit] =
  for {
    z1 <- addW(1, 5)
    z2 <- multiplyW(z1, 7)
  } yield ()

println()
demoW().log.foreach(println)
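Because the log is part of the returned value, it can be inspected directly (in a test, say) instead of being captured from standard output. This self-contained sketch re-declares a Writer-like type under the hypothetical name Logged, using a Vector for cheap appends, and threads it through a foldLeft so that every step of a sum is logged.

```scala
object WriterSketch {

  // A value paired with the log accumulated while computing it
  final case class Logged[A](value: A, log: Vector[String]) {
    def map[B](f: A => B): Logged[B] = Logged(f(value), log)
    def flatMap[B](f: A => Logged[B]): Logged[B] = {
      val next = f(value)
      Logged(next.value, log ++ next.log)
    }
  }

  def tell(msg: String): Logged[Unit] = Logged((), Vector(msg))

  def pure[A](a: A): Logged[A] = Logged(a, Vector.empty)

  // Sum a list, logging each step; nothing is printed
  def sumLogged(xs: List[Int]): Logged[Int] =
    xs.foldLeft(pure(0)) { (acc, x) =>
      for {
        total <- acc
        _     <- tell(s"${total} + ${x}")
      } yield total + x
    }
}
```

For example, WriterSketch.sumLogged(List(1, 2, 3)) yields the value 6 together with the log "0 + 1", "1 + 2", "3 + 3".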

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Use the writer monad for logging' |
  scala-cli --scala 2.13 -
Adding 1 and 5...
Result: 6
Multiplying 6 and 7...
Result: 42

Adding 1 and 5...
Result: 6
Multiplying 6 and 7...
Result: 42

Wrap async callbacks with futures

Callbacks

def exec[A](x: () => A, k: A => Unit): Unit = {
  val t =
    new Thread() {
      override def run(): Unit = {
        val a = x()
        k(a)
      }
    }
  t.start()
}
def addSlowly(x: Int, y: Int): Int = {
  Thread.sleep(100)
  x + y
}

val start = System.currentTimeMillis
exec(
  () => (addSlowly(1,5), addSlowly(3,4)),
  { xy: (Int, Int) =>
    val z = xy._1 * xy._2
    val stop = System.currentTimeMillis
    println(s"result 1: ${z}, runtime: ${stop - start} ms")
  }
)

Future

import scala.concurrent.Future
import scala.concurrent.ExecutionContext
import scala.concurrent.duration.Duration
import scala.concurrent.duration.SECONDS

implicit val ec: ExecutionContext = ExecutionContext.global
val startF = System.currentTimeMillis

val f1 = Future(addSlowly(1,5))
val f2 = Future(addSlowly(3,4))

val f3 = f1.zip(f2) map { case (x, y) => x * y }

// foreach runs the callback only if f3 succeeds (onSuccess is deprecated)
f3.foreach { z =>
  val stopF = System.currentTimeMillis
  println(s"result 2: ${z}, runtime: ${stopF - startF} ms")
}
Thread.sleep(300)
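The Thread.sleep(300) above only guesses how long to wait. A sketch of a sturdier ending, assuming it is acceptable to block once at the outermost edge of the program: start both futures, combine them with a for-comprehension, and wait with Await.result and an explicit timeout. The names FutureSketch, addSlowlyF, product, and run are hypothetical.

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.{Duration, SECONDS}

object FutureSketch {

  implicit val ec: ExecutionContext = ExecutionContext.global

  def addSlowlyF(x: Int, y: Int): Future[Int] =
    Future {
      Thread.sleep(100)
      x + y
    }

  def product(): Future[Int] = {
    // Start both computations first so that they run concurrently
    val f1 = addSlowlyF(1, 5)
    val f2 = addSlowlyF(3, 4)
    for {
      x <- f1
      y <- f2
    } yield x * y
  }

  // Block only at the edge of the program, with an explicit upper bound
  def run(): Int =
    Await.result(product(), Duration(5, SECONDS))
}
```

Creating f1 and f2 before the for-comprehension matters: written inline inside it, the second future would only be created after the first one completes, and the two additions would run one after the other.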

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Wrap async callbacks with futures' |
  scala-cli --scala 2.12 -
result 2: 42, runtime: 432 ms
result 1: 42, runtime: 613 ms

Wrap runtime exceptions in Try

Exceptions

def parseInt(x: String): Int =
  x.toInt

Try

import scala.util.Try
def parseIntT(x: String): Try[Int] =
  Try(x.toInt)
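A short usage sketch with hypothetical names: a Try is either a Success or a Failure, and a for-comprehension over Try values stops at the first Failure, so parse errors flow through as ordinary values until we choose to handle them.

```scala
object TrySketch {

  import scala.util.{Failure, Success, Try}

  def parseIntTry(x: String): Try[Int] = Try(x.toInt)

  // The first Failure short-circuits the rest of the comprehension
  def sum(x: String, y: String): Try[Int] =
    for {
      a <- parseIntTry(x)
      b <- parseIntTry(y)
    } yield a + b

  // Handle both cases explicitly, at the point where we care about the error
  def describe(t: Try[Int]): String =
    t match {
      case Success(n) => s"parsed ${n}"
      case Failure(e) => s"failed: ${e.getClass.getSimpleName}"
    }
}
```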

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Wrap runtime exceptions in `Try`' |
  scala-cli --scala 2.12 -

Replace getters and setters with lenses

case class Lens[A,B](get: A => B, set: (A,B) => A) {
  // focus through this lens (A to B), then through next (B to C)
  def andThen[C](next: Lens[B,C]): Lens[A,C] =
    Lens[A,C](
      a => next.get(get(a)),
      (a, c) => set(a, next.set(get(a), c))
    )
}
case class Contact(email: String, phone: String)
case class Folk(id: String, name: String, contact: Contact)
val contactEmailLens = Lens(
  get = (_: Contact).email,
  set = { (x: Contact, email: String) => x.copy(email = email) }
)

val folkContactLens = Lens(
  get = (_: Folk).contact,
  set = (f: Folk, contact: Contact) => f.copy(contact = contact)
)

val folkEmailLens = folkContactLens andThen contactEmailLens
def demoL(): Unit = {

  val f1: Folk =
    Folk(
      "e9fcb4bd-c821-47c9-bc56-f997c361c1e2",
      "Folky McFolkface",
      Contact(
        "folky@mcfolkface",
        "212-555-4240"
      )
    )
  val f2: Folk = folkEmailLens.set(f1, "mcfolkface@folky")

  println(folkEmailLens.get(f1))
  println(folkEmailLens.get(f2))
}

demoL()
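The get and set pair is enough to derive other operations, such as a modify that reads, transforms, and writes back the focused value. Lenses are also expected to obey laws, for example: getting right after setting returns what was set, and setting what was just gotten changes nothing. A self-contained sketch with hypothetical Person and Address types (its Lens is a re-declaration, not the one above):

```scala
object LensSketch {

  final case class Lens[A, B](get: A => B, set: (A, B) => A) {
    // Read the focused value, transform it, and write it back
    def modify(a: A)(f: B => B): A = set(a, f(get(a)))
  }

  final case class Address(city: String)
  final case class Person(name: String, address: Address)

  val personCity: Lens[Person, String] =
    Lens[Person, String](
      p => p.address.city,
      (p, city) => p.copy(address = p.address.copy(city = city))
    )
}
```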

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace getters and setters with lenses' |
  scala-cli --scala 2.12 -

Convert loops into recursive calls

Loops

def fib(n: Int): Int = {
  var last   = 0
  var result = 1
  // a for loop without yield runs only for its side effects
  for (_ <- 2 to n) {
    val tmp = last
    last = result
    result = result + tmp
  }
  result
}

val fib12: Int = fib(12)
println(s"fib12: ${fib12}")

Recursion

def fibR(n: Int): Int =
  n match {
    case 0 => 0
    case 1 => 1
    case _ => fibR(n - 1) + fibR(n - 2)
  }

val fibR12: Int = fibR(12)
println(s"fibR12: ${fibR12}")
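The recursive fibR mirrors the mathematical definition, but it makes two recursive calls per step, so its running time grows exponentially with n. A tail-recursive variant, sketched below under the hypothetical name fibTR, keeps the loop's two variables as accumulator parameters; @tailrec asks the compiler to verify that the recursion compiles to a loop.

```scala
import scala.annotation.tailrec

object FibSketch {

  def fibTR(n: Int): Int = {
    // i counts steps; last and result play the roles of the loop's vars
    @tailrec
    def go(i: Int, last: Int, result: Int): Int =
      if (i >= n) result
      else go(i + 1, result, last + result)

    // Assumption: negative n is treated like 0
    if (n <= 0) 0 else go(1, 0, 1)
  }
}
```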

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Convert loops into recursive calls' |
  scala-cli --scala 2.12 -
fib12: 144
fibR12: 144

Replace mutable fields with defensive copy

Mutable fields

class Folk(var id: String, var name: String) {

  override def toString(): String =
    s"""|{
        |   "_id": "${System.identityHashCode(this)}",
        |    "id": "${id}",
        |  "name": "${name}"
        |}""".stripMargin
}
def demo(): Unit = {
  val f1 = new Folk("e9fcb4bd-c821-47c9-bc56-f997c361c1e2", "Folky McFolkface")
  val f2 = f1
  f2.id = "6ca07b55-f19c-4a01-b82e-05e44efc905b"

  println(f1.toString())
  println(f2.toString())
}

demo()

Immutable fields

class FolkC(id: String, name: String) {

  def copy(id: String = id, name: String = name): FolkC =
    new FolkC(id = id, name = name)

  override def toString(): String =
    s"""|{
        |   "_id": "${System.identityHashCode(this)}",
        |    "id": "${id}",
        |  "name": "${name}"
        |}""".stripMargin
}
def demoC(): Unit = {
  val f1 = new FolkC("e9fcb4bd-c821-47c9-bc56-f997c361c1e2", "Folky McFolkface")
  val f2 = f1.copy(id = "6ca07b55-f19c-4a01-b82e-05e44efc905b")

  println(f1.toString())
  println(f2.toString())
}

demoC()
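Scala's case classes generate this pattern for us: fields are immutable vals, and the compiler provides a copy method with a default argument for each field, along with structural equality and a readable toString. A minimal sketch (the CopySketch name and sample ids are hypothetical):

```scala
object CopySketch {

  // Case class fields are immutable, and copy is generated by the compiler
  final case class Folk(id: String, name: String)

  val f1: Folk = Folk("id-1", "Folky McFolkface")

  // copy returns a new instance; f1 itself is never modified
  val f2: Folk = f1.copy(id = "id-2")
}
```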

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace mutable fields with defensive copy' |
  scala-cli --scala 2.13 -
{
   "_id": "100555887",
    "id": "6ca07b55-f19c-4a01-b82e-05e44efc905b",
  "name": "Folky McFolkface"
}
{
   "_id": "100555887",
    "id": "6ca07b55-f19c-4a01-b82e-05e44efc905b",
  "name": "Folky McFolkface"
}
{
   "_id": "866191240",
    "id": "e9fcb4bd-c821-47c9-bc56-f997c361c1e2",
  "name": "Folky McFolkface"
}
{
   "_id": "1879492184",
    "id": "6ca07b55-f19c-4a01-b82e-05e44efc905b",
  "name": "Folky McFolkface"
}

Replace function with data structure

Functionalized

def filter[A](k: A => Boolean)(xs: List[A]): List[A] =
  xs match {
    case Nil => Nil
    case h::t => if (k(h)) h :: filter(k)(t) else filter(k)(t)
  }
val isEven: Int => Boolean =
  x => x % 2 == 0

println(filter(isEven)((1 to 10).toList)) // List(2, 4, 6, 8, 10)

Defunctionalized

trait Filter[A] {
  def apply(x: A): Boolean
}

case object IsEven extends Filter[Int] {
  def apply(x: Int): Boolean = isEven(x)
}

def filterD[A](f: Filter[A])(xs: List[A]): List[A] =
  xs match {
    case Nil => Nil
    case h::t => if (f(h)) h :: filterD(f)(t) else filterD(f)(t)
  }
println(filterD(IsEven)((1 to 10).toList)) // List(2, 4, 6, 8, 10)
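IsEven still carries its behavior in an apply method. The idea can be taken one step further by making filters plain data (a sealed ADT) and moving all behavior into a single interpreter function. Because filters are then ordinary values, they can be composed, compared, and printed. The names Pred, Even, GreaterThan, And, eval, and filterP are hypothetical.

```scala
object DefunSketch {

  // Filters as plain data: no functions inside
  sealed trait Pred
  case object Even extends Pred
  final case class GreaterThan(n: Int) extends Pred
  final case class And(p: Pred, q: Pred) extends Pred

  // The single interpreter that gives each constructor its meaning
  def eval(p: Pred, x: Int): Boolean =
    p match {
      case Even           => x % 2 == 0
      case GreaterThan(n) => x > n
      case And(a, b)      => eval(a, x) && eval(b, x)
    }

  def filterP(p: Pred)(xs: List[Int]): List[Int] =
    xs.filter(x => eval(p, x))
}
```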

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace function with data structure' |
  scala-cli --scala 2.13 -
List(2, 4, 6, 8, 10)
List(2, 4, 6, 8, 10)

Replace goto with shift and reset

// Workaround for https://github.com/VirtusLab/scala-cli/issues/2653
//> using option -P:continuations:enable
import scala.util.continuations.shift
import scala.util.continuations.shiftUnit0
import scala.util.continuations.reset
import scala.util.continuations.cpsParam
object go {
  def to(l: Label) = shift { k: (Unit => Unit) =>
    l.k(l)
  }
}

case class Label(k: Label => Unit)

def label = shift { k: (Label => Unit) =>
  k(Label(k))
}

def continue = shift { k: (Unit => Unit) =>
  k()
}
reset {

  var i = 0

  val hello = label
  println("Hello, world!")

  val hola = label
  println("ÂĄHola, mundo!")

  i = i + 1

  if (i < 2) go to hello
  else if (i < 4) go to hola
  else continue

  val goodbye = label
  println("Goodbye!")
}

Try it out

This section is literate Scala, and can be run using Codedown:

To run with scala-cli:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace goto with shift and reset' |
  scala-cli \
    --scala 2.12.2 \
    --compiler-plugin org.scala-lang.plugins:::scala-continuations-plugin:1.0.3 \
    --dependency org.scala-lang.plugins::scala-continuations-library:1.0.3 \
    -P:continuations:enable \
    -
Hello, world!
ÂĄHola, mundo!
Hello, world!
ÂĄHola, mundo!
ÂĄHola, mundo!
ÂĄHola, mundo!
Goodbye!

To run with sbt:

$ cat << EOF > goto.sc
/***
scalaVersion := "2.12.2"
autoCompilerPlugins := true
addCompilerPlugin("org.scala-lang.plugins" % "scala-continuations-plugin_2.12.2" % "1.0.3")
libraryDependencies += "org.scala-lang.plugins" %% "scala-continuations-library" % "1.0.3"
scalacOptions += "-P:continuations:enable"
*/
EOF
$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace goto with shift and reset' >> goto.sc
$ sbt -Dsbt.main.class=sbt.ScriptMain goto.sc
Hello, world!
ÂĄHola, mundo!
Hello, world!
ÂĄHola, mundo!
ÂĄHola, mundo!
ÂĄHola, mundo!
Goodbye!

Wrap nullable values in Option

import java.lang.Integer

Null

def parseInt(x: String): Integer =
  try {
    x.toInt
  } catch {
    case e: NumberFormatException =>
      null
  }

Option

def parseIntO(x: String): Option[Int] =
  try {
    Some(x.toInt)
  } catch {
    case e: NumberFormatException =>
      None
  }
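When the null comes from code we don't control, such as a Java API, Option.apply does the wrapping for us: Option(x) is None when x is null and Some(x) otherwise. In this sketch, lookup is a hypothetical stand-in for such an API.

```scala
object OptionSketch {

  import scala.util.Try

  // Hypothetical stand-in for a Java-style API that returns null for "no result"
  def lookup(key: String): String =
    if (key == "answer") "42" else null

  // Option.apply: null becomes None, anything else becomes Some
  def lookupO(key: String): Option[String] =
    Option(lookup(key))

  // Once wrapped, further steps compose without null checks
  def lookupInt(key: String): Option[Int] =
    lookupO(key).flatMap(s => Try(s.toInt).toOption)
}
```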

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Wrap nullable values in `Option`' |
  scala-cli --scala 2.12 -

Replace inheritance with type classes

Inheritance

trait Named {
  def name: String
}

def showNamed(x: Named): Unit =
  println(x.name)
case class Folk(name: String) extends Named
case class Company(name: String) extends Named
showNamed(Folk("Folky McFolkface"))
showNamed(Company("Company McCompanyface"))

Type class

trait Name[A] {
  def name(x: A): String
}

def showName[A:Name](x: A): Unit =
  println(implicitly[Name[A]].name(x))
object Folk {
  implicit val name: Name[Folk] =
    new Name[Folk] {
      def name(x: Folk) = x.name
    }
}

object Company {
  implicit val name: Name[Company] =
    new Name[Company] {
      def name(x: Company) = x.name
    }
}
showName(Folk("Folky McFolkface"))
showName(Company("Company McCompanyface"))
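One thing type classes allow that inheritance does not is adding behavior to types we don't control: Int can't be made to extend Named, but it can be given a type class instance. Instances can also be derived from other instances. A sketch with a hypothetical Show type class:

```scala
object TypeClassSketch {

  trait Show[A] {
    def show(x: A): String
  }

  object Show {
    // Summon the instance for A from implicit scope
    def apply[A](implicit s: Show[A]): Show[A] = s

    // An instance for a type we don't own
    implicit val intShow: Show[Int] =
      new Show[Int] {
        def show(x: Int): String = s"Int(${x})"
      }

    // A derived instance: List[A] is showable whenever A is
    implicit def listShow[A](implicit s: Show[A]): Show[List[A]] =
      new Show[List[A]] {
        def show(xs: List[A]): String = xs.map(s.show).mkString("[", ", ", "]")
      }
  }

  def render[A: Show](x: A): String = Show[A].show(x)
}
```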

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Replace inheritance with type classes' |
  scala-cli --scala 2.12 -
Folky McFolkface
Company McCompanyface
Folky McFolkface
Company McCompanyface

Parse with validation

Perhaps

case class User(name: String, email: String, phone: String)

val name  = ""
val email = "nodomain@"
val phone = "212-555-fone"

// non-capturing group, so that the nameR() pattern (no groups) can match
val nameR  = """\w+(?:\s\w+)*""".r
val emailR = """[^@]+@[^@]+""".r
val phoneR = """\d{3}-\d{3}-\d{4}""".r
val userO: Option[User] =
  (name, email, phone) match {
    case (nameR(), emailR(), phoneR()) => Some(User(name, email, phone))
    case _                             => None
  }

Validation

sealed trait Validation[E,A] {
  def map[B](f: A => B): Validation[E,B] =
    this match {
      case Success(a)  => Success(f(a))
      case Failure(es) => Failure(es)
    }
  def ap[B](ff: Validation[E,A => B]): Validation[E,B] =
    this match {
      case Success(a) =>
        ff match {
          case Success(f)  => Success(f(a))
          case Failure(es) => Failure(es)
        }
      case Failure(es) =>
        ff match {
          case Success(f)   => Failure(es)
          case Failure(es2) => Failure(es ++ es2)
        }
    }
}
case class Success[E,A](value: A) extends Validation[E,A]
case class Failure[E,A](errors: List[E]) extends Validation[E,A]

implicit class RegexValidation(r: scala.util.matching.Regex) {
  def validate(x: String): Validation[String,String] =
    x match {
      case r() => Success(x)
      case _   => Failure(List(s""""${x}""""))
    }
}
val userV: Validation[String,User] =
  nameR.validate(name).ap(
    emailR.validate(email).ap(
      phoneR.validate(phone).map(
        // ap applies the innermost value (phone) first and name last
        (p: String) => (e: String) => (n: String) => User(n, e, p)
      )
    )
  )

println(s"userV: ${userV}")
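A successful parse is where argument order matters. This self-contained sketch is an alternative formulation, not the Validation type above: it uses the standard library's Either to hold either the accumulated errors or a value, and a hypothetical map3 that combines three results positionally, so each field lands in the right place and every error is reported.

```scala
import scala.util.matching.Regex

object ValidationSketch {

  // Left holds every error found so far; Right holds a valid value
  type V[A] = Either[List[String], A]

  def check(label: String, r: Regex)(x: String): V[String] =
    if (r.pattern.matcher(x).matches()) Right(x) else Left(List(s"invalid ${label}"))

  // Combine three results positionally, accumulating all errors
  def map3[A, B, C, D](va: V[A], vb: V[B], vc: V[C])(f: (A, B, C) => D): V[D] =
    (va, vb, vc) match {
      case (Right(a), Right(b), Right(c)) => Right(f(a, b, c))
      case _ => Left(List(va, vb, vc).collect { case Left(es) => es }.flatten)
    }

  final case class User(name: String, email: String, phone: String)

  def user(name: String, email: String, phone: String): V[User] =
    map3(
      check("name", """\w+(?:\s\w+)*""".r)(name),
      check("email", """[^@]+@[^@]+""".r)(email),
      check("phone", """\d{3}-\d{3}-\d{4}""".r)(phone)
    )(User(_, _, _))
}
```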

Try it out

This section is literate Scala, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown scala --section '## Parse with validation' |
  scala-cli --scala 2.12 -
userV: Failure(List("", "nodomain@", "212-555-fone"))

Progressions of purity

Let's implement a simple multiplication function in different languages, and examine the implications of each type.

JavaScript

function multiply(x, y) {
  console.log('multiplying', x, 'and', y);
  return x * y;
}

Let's find out what its type is.

console.log(typeof multiply);
// function

Welp, that doesn't tell us a whole lot.

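Here's what we know from the type:

  1. multiply is a function

Here's what we don't know from the type:

  1. how many arguments it expects
  2. what types its arguments should be
  3. what, if anything, it returns
  4. whether it has side effects, such as logging to the console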

Scala

def multiply(x: Int, y: Int): Int = {
  println(s"multiplying ${x} and ${y}")
  x * y
}

Let's find out what type it is.

:t multiply _
// (Int, Int) => Int

This tells us a bit more.

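Here's what we know from the type:

  1. it takes two arguments, both of type Int
  2. it returns an Int

Here's what we still don't know from the type:

  1. whether it has side effects, such as printing to the console
  2. whether it can throw an exception instead of returning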

Haskell

multiply :: Int -> Int -> Int
multiply x y = x * y

We've declared its type to be Int -> Int -> Int. This tells us quite a lot.

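Here's what we know from the type:

  1. it takes two Int arguments and returns an Int
  2. it performs no I/O, such as printing, since IO does not appear in its type

Here's what we still don't know from the type:

  1. how the arguments are used, or whether they are used at all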

Linear Haskell

multiply :: Int -o Int -o Int
multiply x y = x * y

In addition to the above, the linear arrow -o tells us something about how the arguments are used.

From the type, we don't know how the arguments are used (i.e. whether they are multiplied), but we do know that they are both used, and that they are each used only once.


Parametricity

An introduction to the usefulness of parametric polymorphism.

Thought experiment

Consider the following JavaScript function:

function foo(x) {
  // implementation hidden
}

What can we infer about this function? The short answer is: nothing.

This function looks like it takes a single argument, but JavaScript gives a function access to additional, undeclared arguments via the arguments object. Even if this function only considers its declared argument, we don't know what type the argument is expected to be. It could be a number, or an object, or anything in between.

There's no way to know what sort of data structure this function returns, or whether it returns anything at all.

Essentially, we have no way of knowing anything about what this function does, or how we can use it.

Enter types

Using TypeScript, let's tweak this function's declaration and see what happens:

function foo<A>(x: A): A {
  // implementation hidden
}

We've made three small changes:

  1. Declare A as a type variable
  2. Declare A as the type of the function's input argument
  3. Declare A as the type of the function's output value

Now let's see what we can infer about this function.

The function takes exactly one argument of any type we like. That type is represented by the type variable A.

The function returns a value, also of type A. Since the function's input and output types share the same type variable A, its return value has the same type as whatever argument was passed in.

If we call foo(42) with the number 42, then A means number, and the function will return a number.

If we call foo("forty two") with the string "forty two", then A means string, and the function will return a string.

The developer writing the implementation can't know anything about what A will be in practice. Is x a number? Maybe it's a string. Perhaps it's an array or an object. It can be any of these, and no matter which it is, the function has to return a value of the same type.

The implementation

It turns out that there's only one possible implementation for a function with this type: identity.

If we ignore bottom cases like null, undefined, and exceptions, the only way to guarantee that we return a value with the same type as the argument is to simply return the argument itself.

function foo<A>(x: A): A {
  return x;
}

This is a property of parametrically polymorphic functions known as parametricity, and is quite useful for API design, library consumption, and generally reasoning about code.

var x: number = foo(42);
console.log(x + ':', typeof x); // 42: number

var y: string = foo("42");
console.log(y + ':', typeof y); // 42: string

Parametricity also applies to more complicated (and more interesting) functions, and allows us to derive useful theorems about them using only their type signatures.

Affordable theorems

Say that r is a function of type:

r : ∀X. X* → X*

In TypeScript, this can be written as:

declare function r<A>(xs: Array<A>): Array<A>;

An example of a function of this type is reverse:

function reverse<A>(xs: Array<A>): Array<A> {
  var reversed: Array<A> = [];
  for (var i = xs.length - 1; i >= 0; i--) {
    reversed.push(xs[i]);
  }
  return reversed;
}

Say also that a is a total function of type:

a : A → A'

In TypeScript, taking A to be number and A' to be string, this can be written as:

declare function a(x: number): string;

From what we learned in Theorems for free!, we know that:

$$a^* \circ r_A = r_{A'} \circ a^*$$

We can demonstrate this with Node.js using jsverify to generate the function a and sample input xs:

import jsc = require('jsverify');

function arrEq(xs, ys) {
  var eq = xs.length === ys.length;
  for (var i = 0; i < xs.length && eq; i++) { eq = xs[i] === ys[i]; }
  return eq;
}

function map(f) {
  return function (xs) {
    var ys = [];
    for (var i = 0; i < xs.length; i++) { ys.push(f(xs[i])); }
    return ys;
  };
}

function compose(f, g) { return function (x) { return f(g(x)); }; }

jsc.assert(
  jsc.forall('array number', 'number -> string', function (xs, a) {
    return arrEq( compose(reverse, map(a))(xs)
                , compose(map(a), reverse)(xs)
        );
  })
);
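The theorem holds for every function of type ∀X. X* → X*, not only reverse, because such a function can only rearrange, drop, or duplicate elements without inspecting them. A small Scala spot check, assuming a hypothetical ListEndo trait to express the polymorphic type, compares both sides of the equation for reverse and for a rotation:

```scala
object FreeTheoremSketch {

  // A value of this trait is a function of type ∀X. List[X] => List[X]
  trait ListEndo {
    def apply[X](xs: List[X]): List[X]
  }

  val reverseE: ListEndo = new ListEndo {
    def apply[X](xs: List[X]): List[X] = xs.reverse
  }

  // Move the first element to the end
  val rotateE: ListEndo = new ListEndo {
    def apply[X](xs: List[X]): List[X] = xs.drop(1) ++ xs.take(1)
  }

  // Compare a* after r_A with r_A' after a*, on a single input
  def holds[A, B](r: ListEndo, a: A => B, xs: List[A]): Boolean =
    r(xs).map(a) == r(xs.map(a))
}
```

Unlike the jsverify property, this checks only the inputs it is given; it illustrates the theorem rather than testing it broadly.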

Progressions of concision

Let's take a contrasting look at the complexity of implementing data structures in different languages.

Disjunction

Let's implement a simple sum type, Either, in a few polymorphic languages.

Language   Source         Lines of code
Java       Either.java    75
Scala      Either.scala   3
Haskell    Either.hs      1

Either.java

class NoValueException extends Exception { }

interface Either<A,B> {

  boolean isLeft();
  boolean isRight();

  public A leftValue() throws NoValueException;
  public B rightValue() throws NoValueException;

}

class Left<A,B> implements Either<A,B> {

  public final A a;

  public Left(A a) {
    this.a = a;
  }

  @Override
  public boolean isLeft() {
    return true;
  }

  @Override
  public boolean isRight() {
    return false;
  }

  @Override
  public A leftValue() throws NoValueException {
    return a;
  }

  @Override
  public B rightValue() throws NoValueException {
    throw new NoValueException();
  }

  @Override
  public boolean equals(Object x) {
    return toString().equals(x.toString());
  }

  @Override
  public int hashCode() {
    return toString().hashCode();
  }

  @Override
  public String toString() {
    return "Left(" + a.toString() + ")";
  }

}

class Right<A,B> implements Either<A,B> {

  public final B b;

  public Right(B b) {
    this.b = b;
  }

  @Override
  public boolean isLeft() {
    return false;
  }

  @Override
  public boolean isRight() {
    return true;
  }

  @Override
  public A leftValue() throws NoValueException {
    throw new NoValueException();
  }

  @Override
  public B rightValue() throws NoValueException {
    return b;
  }

  @Override
  public boolean equals(Object x) {
    return toString().equals(x.toString());
  }

  @Override
  public int hashCode() {
    return toString().hashCode();
  }

  @Override
  public String toString() {
    return "Right(" + b.toString() + ")";
  }

}

Either.scala

sealed trait Either[A,B]
case class Left[A,B](a: A) extends Either[A,B]
case class Right[A,B](b: B) extends Either[A,B]

Either.hs

data Either a b = Left a | Right b deriving (Eq, Show)

Composable disjunction

Either is useful for storing a value of one of two types, but it isn't useful for building up behavior from multiple disjunctions.

Let's make it composable by adding map, ap, and flatMap (a.k.a. fmap, <*>, and >>=), and biasing it to the right.

Language   Source                   Lines of code
Java       RightBiasedEither.java   103
Scala      RightBiasedEither.scala  15
Haskell    RightBiasedEither.hs     9

RightBiasedEither.java

import java.util.function.Function;

class NoValueException extends Exception { }

interface RightBiasedEither<A,B> {

  boolean isLeft();
  boolean isRight();

  A leftValue() throws NoValueException;
  B rightValue() throws NoValueException;

  <C> RightBiasedEither<A,C> map(Function<B,C> f);
  <C> RightBiasedEither<A,C> ap(RightBiasedEither<A,Function<B,C>> fe);
  <C> RightBiasedEither<A,C> flatMap(Function<B,RightBiasedEither<A,C>> f);

}

class Left<A,B> implements RightBiasedEither<A,B> {

  public final A a;

  public Left(A a) {
    this.a = a;
  }

  @Override
  public boolean isLeft() {
    return true;
  }

  @Override
  public boolean isRight() {
    return false;
  }

  @Override
  public A leftValue() throws NoValueException {
    return a;
  }

  @Override
  public B rightValue() throws NoValueException {
    throw new NoValueException();
  }

  @Override
  public boolean equals(Object x) {
    return toString().equals(x.toString());
  }

  @Override
  public int hashCode() {
    return toString().hashCode();
  }

  @Override
  public String toString() {
    return "Left(" + a.toString() + ")";
  }

  @Override
  public <C> RightBiasedEither<A,C> map(Function<B,C> f) {
    return new Left<>(a);
  }

  @Override
  public <C> RightBiasedEither<A,C> ap(RightBiasedEither<A,Function<B,C>> fe) {
    return new Left<>(a);
  }

  @Override
  public <C> RightBiasedEither<A,C> flatMap(Function<B,RightBiasedEither<A,C>> f) {
    return new Left<>(a);
  }

}

class Right<A,B> implements RightBiasedEither<A,B> {

  public final B b;

  public Right(B b) {
    this.b = b;
  }

  @Override
  public boolean isLeft() {
    return false;
  }

  @Override
  public boolean isRight() {
    return true;
  }

  @Override
  public A leftValue() throws NoValueException {
    throw new NoValueException();
  }

  @Override
  public B rightValue() throws NoValueException {
    return b;
  }

  @Override
  public boolean equals(Object x) {
    return toString().equals(x.toString());
  }

  @Override
  public int hashCode() {
    return toString().hashCode();
  }

  @Override
  public String toString() {
    return "Right(" + b.toString() + ")";
  }

  @Override
  public <C> RightBiasedEither<A,C> map(Function<B,C> f) {
    return new Right<>(f.apply(b));
  }

  @Override
  public <C> RightBiasedEither<A,C> ap(RightBiasedEither<A,Function<B,C>> fe) {
    return fe.map((f) -> f.apply(b));
  }

  @Override
  public <C> RightBiasedEither<A,C> flatMap(Function<B,RightBiasedEither<A,C>> f) {
    return f.apply(b);
  }

}

RightBiasedEither.scala

sealed trait RightBiasedEither[A,B] {
  def map[C](f: B => C): RightBiasedEither[A,C]
  def ap[C](fe: RightBiasedEither[A,B => C]): RightBiasedEither[A,C]
  def flatMap[C](f: B => RightBiasedEither[A,C]): RightBiasedEither[A,C]
}

case class Left[A,B](a: A) extends RightBiasedEither[A,B] {
  def map[C](f: B => C): RightBiasedEither[A,C] = Left(a)
  def ap[C](fe: RightBiasedEither[A,B => C]): RightBiasedEither[A,C] = Left(a)
  def flatMap[C](f: B => RightBiasedEither[A,C]): RightBiasedEither[A,C] = Left(a)
}

case class Right[A,B](b: B) extends RightBiasedEither[A,B] {
  def map[C](f: B => C): RightBiasedEither[A,C] = Right(f(b))
  def ap[C](fe: RightBiasedEither[A,B => C]): RightBiasedEither[A,C] = fe.map(f => f(b))
  def flatMap[C](f: B => RightBiasedEither[A,C]): RightBiasedEither[A,C] = f(b)
}
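For comparison, the standard library's Either has been right-biased since Scala 2.12: its map and flatMap operate on Right and pass Left through unchanged, so it works directly in for-comprehensions. A sketch using the standard Either (not the RightBiasedEither above), with hypothetical parsePositive and ratio helpers:

```scala
import scala.util.Try

object EitherSketch {

  def parsePositive(s: String): Either[String, Int] =
    Try(s.toInt).toOption match {
      case Some(n) if n > 0 => Right(n)
      case Some(n)          => Left(s"not positive: ${n}")
      case None             => Left(s"not a number: ${s}")
    }

  // The first Left short-circuits the rest of the comprehension
  def ratio(a: String, b: String): Either[String, Int] =
    for {
      x <- parsePositive(a)
      y <- parsePositive(b)
    } yield x / y
}
```

Because parsePositive rejects zero, the division in ratio never divides by zero.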

RightBiasedEither.hs

data RightBiasedEither a b = Left a | Right b deriving (Eq, Show)

instance Functor (RightBiasedEither a) where
  fmap f e = e >>= (\x -> pure (f x))

instance Applicative (RightBiasedEither a) where
  pure b = Right b
  (<*>) fe e = fe >>= (\f -> fmap f e)

instance Monad (RightBiasedEither a) where
  (>>=) (Right b) f = f b
  (>>=) (Left a) _  = Left a

Move dependency injection from run-time to compile-time

Before

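import com.google.inject.AbstractModule;
import com.google.inject.BindingAnnotation;
import com.google.inject.Guice;
import com.google.inject.Inject;
import com.google.inject.Injector;
import com.google.inject.Provides;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

class Before {

  // Guice only uses this as a qualifier if it is a binding annotation
  // retained at run time; otherwise Guice ignores it
  @BindingAnnotation
  @Retention(RetentionPolicy.RUNTIME)
  @Target({ElementType.PARAMETER, ElementType.METHOD})
  @interface Greeting {}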

  static class EnModule extends AbstractModule {
    @Provides
    @Greeting
    static String getGreeting() {
      return "Hello, world!";
    }
  }

  static class Greeter {
    private final String message;

    @Inject
    Greeter(@Greeting final String message) {
      this.message = message;
    }
  }

  static void run(final Greeter g) {
    System.out.println(g.message);
  }

  public static void run() {
    final Injector injector = Guice.createInjector(new EnModule());
    final Greeter greeter = injector.getInstance(Greeter.class);
    run(greeter);
  }
}

After

class After {

  static class EnModule {
    final String getGreeting() {
      return "Hello, world!";
    }
  }

  static class Greeter {
    private final String message;

    Greeter(final String message) {
      this.message = message;
    }
  }

  static void run(final Greeter g) {
    System.out.println(g.message);
  }

  public static void run() {
    final Greeter greeter = new Greeter(new EnModule().getGreeting());
    run(greeter);
  }
}
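In Scala, the compiler can also do the wiring. In this sketch (hypothetical names Greeting, Greeter, and EnglishWiring), Greeter declares its dependency as an implicit constructor parameter. If no Greeting is in implicit scope where a Greeter is constructed, the program fails to compile, rather than failing at run time as a missing Guice binding would.

```scala
object CompileTimeDISketch {

  // A dedicated type for the dependency, rather than a bare String
  final case class Greeting(message: String)

  // The dependency is declared as an implicit constructor parameter
  final class Greeter()(implicit greeting: Greeting) {
    def greet(): String = greeting.message
  }

  object EnglishWiring {
    implicit val english: Greeting = Greeting("Hello, world!")

    // Resolved by the compiler: without `english` in scope, this does not compile
    def greeter(): Greeter = new Greeter()
  }
}
```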

Try it out

public class Main {
  public static void main(final String[] args) {
    Before.run();
    After.run();
  }
}

This section is literate Java, and can be run using Codedown:

$ curl https://earldouglas.com/itof.md |
  codedown java --section '## Move dependency injection from run-time to compile-time' |
  scala-cli --dep com.google.inject:guice:7.0.0 _.java
Hello, world!
Hello, world!