Ad-hoc polymorphism in Standard ML

October 17, 2013

I recently started learning SML, and among the first tools I needed were functor and monad instances for some simple SML collection types. This led me to signatures and structures.

Consider the following:

val xs = [1,2,3]
fun plusOne(x) = x + 1

To apply plusOne to each element in xs, we can use List#map from the SML Basis Library:

val xs2 = List.map(plusOne)(xs) (* [2,3,4] *)

The type of List#map is ('a -> 'b) -> 'a list -> 'b list.
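For intuition, here is a minimal sketch of what a map over lists does, written with plain pattern matching. The name myMap is hypothetical and only chosen to avoid shadowing List.map; it has the same curried type.

```sml
(* Hypothetical hand-written equivalent of List.map:
   ('a -> 'b) -> 'a list -> 'b list *)
fun myMap f [] = []
  | myMap f (x :: rest) = f x :: myMap f rest

fun plusOne(x) = x + 1

val ys = myMap plusOne [1, 2, 3]         (* [2,3,4] *)
val same = (ys = List.map plusOne [1, 2, 3])  (* true *)
```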

But what if we didn't have access to List#map, or if we wanted a way to generalize map to a common interface? We can define such an interface using a signature:

signature Functor =
sig
  type 'a t
  val map: ('a -> 'b) -> 'a t -> 'b t
end

This signature says that a Functor abstracts over some type constructor t and provides a map function: given a function of type 'a -> 'b, for arbitrary types 'a and 'b, map applies it to whatever values of type 'a an 'a t contains, producing a 'b t.

That's a lot of alphabet soup, so let's be more concrete. To implement a signature, we write a structure:

structure ListFunctor:Functor =
struct
  type 'a t = 'a list
  val map = fn f => fn xs => List.map(f)(xs)
end

This structure implements Functor with t as list. This lets us apply an 'a -> 'b function to each element of an 'a list, resulting in a 'b list. Behind the scenes we just defer to List#map.

Now we can use it to map plusOne over xs:

val xs2 = ListFunctor.map(plusOne)(xs) (* [2,3,4] *)
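The payoff of the signature is that other type constructors can implement it too, and code can be written once against Functor. Below is a minimal sketch with the Functor signature and ListFunctor repeated so it stands alone. OptionFunctor and IncrAll are hypothetical names introduced for illustration; Option.map is from the Basis Library. Note that SML's functor keyword (a module parameterized by a structure) is a different thing from the Functor signature.

```sml
signature Functor =
sig
  type 'a t
  val map: ('a -> 'b) -> 'a t -> 'b t
end

structure ListFunctor:Functor =
struct
  type 'a t = 'a list
  val map = fn f => fn xs => List.map(f)(xs)
end

(* Hypothetical second instance: option, delegating to Option.map *)
structure OptionFunctor:Functor =
struct
  type 'a t = 'a option
  val map = fn f => fn opt => Option.map(f)(opt)
end

(* Hypothetical SML functor written once against the Functor signature;
   it works for any structure that matches Functor *)
functor IncrAll(F: Functor) =
struct
  fun incr (xs: int F.t) : int F.t = F.map (fn x => x + 1) xs
end

structure ListIncr = IncrAll(ListFunctor)
structure OptionIncr = IncrAll(OptionFunctor)

val a = ListIncr.incr [1, 2, 3]    (* [2,3,4] *)
val b = OptionIncr.incr (SOME 41)  (* SOME 42 *)
val c = OptionIncr.incr NONE       (* NONE *)
```

Because both instances use transparent ascription (:), the type t stays visible after functor application, so ListIncr.incr accepts an ordinary int list.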

Now let's see if we can extend this and make a monad for lists. Here's our signature:

signature Monad =
sig
  include Functor
  val flatMap: 'a t -> ('a -> 'b t) -> 'b t
end

This signature extends Functor using the include keyword and adds one more value specification, flatMap. It is nearly the same as map, except that the function being applied returns a 'b t instead of a plain 'b (and, in this signature, the collection comes first and the function second).
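To see why a separate flatMap is useful, consider what plain List.map produces when the function being mapped itself returns a list: the result is nested one level deeper. A small illustration, using a repeat function that triples its argument:

```sml
(* A function that returns a list rather than a single value *)
fun repeat(x) = [x,x,x]

(* Mapping it yields an int list list, not an int list *)
val nested = List.map(repeat)([1,2])  (* [[1,1,1],[2,2,2]] *)
```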

To implement flatMap for lists, we can't delegate directly to any single List function, but we can cheat by combining List#map with List#concat:

structure ListMonad:Monad =
struct
  open ListFunctor
  val flatMap = fn xs => fn f => List.concat(List.map(f)(xs))
end

The type of List#concat is 'a list list -> 'a list.

This structure pulls in ListFunctor's type t and map using the open keyword, and implements flatMap by first mapping f over xs and then flattening the resulting list of lists.
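The same recipe (map, then flatten one level) carries over to other type constructors. As a sketch, with the signatures repeated so the snippet is self-contained, an option instance can use the Basis Library's Option.join, of type 'a option option -> 'a option, in place of List.concat. OptionMonad and half are hypothetical names used only for illustration.

```sml
signature Functor =
sig
  type 'a t
  val map: ('a -> 'b) -> 'a t -> 'b t
end

signature Monad =
sig
  include Functor
  val flatMap: 'a t -> ('a -> 'b t) -> 'b t
end

(* Hypothetical option instance: map f over the option, then
   collapse the resulting 'b option option with Option.join *)
structure OptionMonad:Monad =
struct
  type 'a t = 'a option
  val map = fn f => fn opt => Option.map(f)(opt)
  val flatMap = fn opt => fn f => Option.join(Option.map(f)(opt))
end

(* Hypothetical partial function: halve only even numbers *)
fun half(n) = if n mod 2 = 0 then SOME (n div 2) else NONE

val r1 = OptionMonad.flatMap(SOME 8)(half)  (* SOME 4 *)
val r2 = OptionMonad.flatMap(SOME 3)(half)  (* NONE *)
val r3 = OptionMonad.flatMap(NONE)(half)    (* NONE *)
```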

Here's what using flatMap looks like:

fun repeat(x) = [x,x,x]
val xs3 = ListMonad.map(plusOne)(xs) (* [2,3,4] *)
val xs4 = ListMonad.flatMap(xs)(repeat) (* [1,1,1,2,2,2,3,3,3] *)

Thanks to include and open, which play roughly the role that inheritance would in an object-oriented language, both map (from ListFunctor) and flatMap are available directly on the ListMonad structure.
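Because ListMonad matches the Monad signature, code can also be written once against Monad. In this minimal sketch (signatures and ListMonad repeated so it stands alone), the hypothetical functor MonadUtils derives a join, which flattens one level of nesting, from flatMap alone; applied to the list instance it behaves like List.concat.

```sml
signature Functor =
sig
  type 'a t
  val map: ('a -> 'b) -> 'a t -> 'b t
end

signature Monad =
sig
  include Functor
  val flatMap: 'a t -> ('a -> 'b t) -> 'b t
end

structure ListMonad:Monad =
struct
  type 'a t = 'a list
  val map = fn f => fn xs => List.map(f)(xs)
  val flatMap = fn xs => fn f => List.concat(List.map(f)(xs))
end

(* Hypothetical generic helper: join is flatMap with the identity function *)
functor MonadUtils(M: Monad) =
struct
  fun join (mm: 'a M.t M.t) : 'a M.t = M.flatMap mm (fn m => m)
end

structure ListUtils = MonadUtils(ListMonad)

val flat = ListUtils.join [[1,2],[3],[]]  (* [1,2,3] *)
```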