The free monad without the boilerplate


The free monad is really neat for creating DSLs, as it lets you completely separate the business logic from the implementation.

It leaves great freedom in the implementation choices, and should you change your implementation you just need to rewrite your interpreter without changing any of the business logic.

That’s great! However, after playing around a little with the free monad, I was a bit skeptical about the amount of boilerplate required to plug everything together. Writing smart constructors and injectors is not the most trivial thing (depending on how familiar your team is with functional programming).

To conclude this series on the free monad we’ll have a look at some solutions to remove this boilerplate code, especially freasy and freek.

Freasy

Freasy is a nice library that makes writing a free monad a breeze. It does just what it says: it scraps the boilerplate. Everything looks similar to what we’ve seen so far. It also comes with an IntelliJ plugin to sort out all the compilation errors that show up in the IDE.

Basically all you need to do is declare a trait containing a sealed trait, a type alias and all the methods that form your DSL, then mark the outer trait with the @free annotation. This annotation triggers a macro that lifts all the abstract methods into the DSL data types.
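To make the idea concrete, here is a rough sketch of the kind of code you would otherwise write by hand for the Logging DSL: one case class per method, plus a smart constructor that lifts it into Free. It is only an illustration and not the code freasy actually generates. The tiny Free type and the SmartConstructorSketch object are stand-ins invented for this example so that it runs without any library.

```scala
object SmartConstructorSketch {
  // A deliberately tiny free monad standing in for cats.free.Free (not stack-safe).
  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B]): Free[F, B] = FlatMapped[F, A, B](this, f)
    def map[B](f: A => B): Free[F, B] = flatMap(a => Pure[F, B](f(a)))
  }
  final case class Pure[F[_], A](value: A) extends Free[F, A]
  final case class Suspend[F[_], A](instruction: F[A]) extends Free[F, A]
  final case class FlatMapped[F[_], A, B](sub: Free[F, A], next: A => Free[F, B])
      extends Free[F, B]

  // Hand-written data types: one case class per DSL operation.
  sealed trait Log[A]
  final case class Error(message: String) extends Log[Unit]
  final case class Info(message: String)  extends Log[Unit]

  type LogF[A] = Free[Log, A]

  // Hand-written smart constructors: each one lifts an instruction into Free.
  def error(message: String): LogF[Unit] = Suspend[Log, Unit](Error(message))
  def info(message: String): LogF[Unit]  = Suspend[Log, Unit](Info(message))

  // With the constructors in place, a program reads like ordinary code.
  val program: LogF[Unit] =
    for {
      _ <- info("starting")
      _ <- error("something went wrong")
    } yield ()
}
```

With the annotation you only write the two method signatures; the case classes and the lifting are the part you no longer have to maintain.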

Enough talking, let’s get back to our stock example and see what it looks like with freasy. It all starts with the DSL ADT, so let’s start by defining our DSL grammar.

import cats.{Id, ~>} // Id and ~> are used by the interpreters below
import cats.free._
import freasymonad.cats.free

type LocationId      = String
type ProductId       = String
type PurchaseOrderId = String

@free trait Stock {
  sealed trait Movement[A]
  type MovementF[A] = Free[Movement, A]

  def poContent(poId: PurchaseOrderId): MovementF[Map[ProductId, Int]]
  def createStock(productId: ProductId, locationId: LocationId, quantity: Int): MovementF[Unit]
  def moveStock(productId: ProductId, source: LocationId, destination: LocationId, quantity: Int): MovementF[Unit]
}

@free trait Inventory {
  sealed trait Update[A]
  type InventoryF[A] = Free[Update, A]

  def increment(locationId: LocationId, productId: ProductId, quantity: Int): InventoryF[Unit]
  def decrement(locationId: LocationId, productId: ProductId, quantity: Int): InventoryF[Unit]
}

@free trait Logging {
  sealed trait Log[A]
  type LogF[A] = Free[Log, A]

  def error(message: String): LogF[Unit]
  def info(message: String): LogF[Unit]
}

@free trait Storage {
  sealed trait Query[A]
  type QueryF[A] = Free[Query, A]

  def get[K, V](key: K): QueryF[Option[V]]
  def put[K, V](key: K, value: V): QueryF[Unit]
}

And that’s all for the data types. No more case classes or inject classes. All this boilerplate is generated by the Scala macro. Pretty neat!

Next let’s define the interpreters. There’s a slight difference here as the interpreters are no longer natural transformation instances (remember the ~> sign). The natural transformation hasn’t disappeared either. Let’s look at an example first.

val consoleLogger = new Logging.Interp[Id] {
  def error(message: String): Id[Unit] =
    println(s"[ERROR] [${System.currentTimeMillis()}] $message")
  def info(message: String): Id[Unit] =
    println(s"[INFO ] [${System.currentTimeMillis()}] $message")
}

The consoleLogger is an instance of Logging.Interp[Id] and not directly a (Log ~> Id). All you have to do is implement the abstract methods that were defined in the DSL. You can still get the natural transformation with consoleLogger.interpreter.
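One way to picture the relationship between a method-based interpreter and its natural transformation is the following library-free sketch. The names Interp and interpreter mirror the ones used above, but the implementation is a hand-written assumption and not freasy’s generated code; ~>, Id and the recording logger are defined locally for the example.

```scala
object InterpSketch {
  // Identity "effect": Id[A] is just A.
  type Id[A] = A

  // Hand-rolled natural transformation, standing in for cats' ~>.
  trait ~>[F[_], G[_]] {
    def apply[A](fa: F[A]): G[A]
  }

  // The instruction set corresponding to the DSL methods.
  sealed trait Log[A]
  final case class Error(message: String) extends Log[Unit]
  final case class Info(message: String)  extends Log[Unit]

  // Method-based interpreter: one abstract method per DSL operation.
  trait Interp[M[_]] {
    def error(message: String): M[Unit]
    def info(message: String): M[Unit]

    // The natural transformation is derived by dispatching each
    // instruction to the method of the same name.
    def interpreter: Log ~> M = new (Log ~> M) {
      def apply[A](fa: Log[A]): M[A] = fa match {
        case Error(message) => error(message)
        case Info(message)  => info(message)
      }
    }
  }

  // Usage: an interpreter into Id that records messages instead of printing them.
  val recorded = scala.collection.mutable.ListBuffer.empty[String]
  val recordingLogger: Interp[Id] = new Interp[Id] {
    def error(message: String): Id[Unit] = { recorded += s"[ERROR] $message"; () }
    def info(message: String): Id[Unit]  = { recorded += s"[INFO ] $message"; () }
  }

  recordingLogger.interpreter(Info("Inbounding po-1"))
}
```

The pattern match we used to write in every interpreter is still there; it has simply moved into one generic place.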

Let’s define the other interpreters that we need for our program:

val keyValueStore = new Storage.Interp[Id] {
  var store = Map.empty[Any, Any]
  def get[K, V](key: K) = store.get(key).asInstanceOf[Option[V]]
  def put[K, V](key: K, value: V) = store += key -> value
}

import cats.data.Coproduct
import Inventory.ops._

// same old trick: Concatenate the types into a Coproduct
type StorageAndLog[A]  = Coproduct[Storage.Query, Logging.Log, A]
type StorageAndLogF[A] = Free[StorageAndLog, A]

// Like before we need to pass the injects classes implicitly
def inventoryInterpreter(
  implicit logs: Logging.Injects[StorageAndLog], 
  queries: Storage.Injects[StorageAndLog]) = {

  import logs._, queries._
  new Inventory.Interp[StorageAndLogF] {
    def increment(locationId: LocationId, productId: ProductId, quantity: Int): StorageAndLogF[Unit] =
      for {
        _ <- info(s"Add $quantity $productId at $locationId")
        existing <- get[(LocationId, ProductId), Int]((locationId, productId)).map(_ getOrElse 0)
        _ <- put((locationId, productId), existing + quantity)
      } yield ()

    def decrement(locationId: LocationId, productId: ProductId, quantity: Int): StorageAndLogF[Unit] =
      for {
        _ <- info(s"Remove $quantity $productId from $locationId")
        existing <- get[(LocationId, ProductId), Int]((locationId, productId)).map(_ getOrElse 0)
        _ <- put((locationId, productId), existing - quantity)
      } yield ()
  }
}

And finally our inbound program:

import cats.implicits._

type StockApp[A] = Coproduct[Stock.Movement, Logging.Log, A]

def inboundStock(poId: PurchaseOrderId): Free[StockApp, Unit] = {
  // we can either import injectOps._
  // or pass implicit parameters for Injects[StockApp]
  // like we did for the inventoryInterpreter
  import Logging.injectOps._
  import Stock.injectOps._
  for {
    _       <- info[StockApp](s"Inbounding $poId")
    content <- poContent[StockApp](poId)
    _       <- content.toList.traverseU {
      case (productId, quantity) =>
        createStock[StockApp](productId, "dock", quantity)
    }
    _       <- content.toList.traverseU {
      case (productId, quantity) =>
        moveStock[StockApp](productId, "dock", "warehouse", quantity)
    }
  } yield ()
}

In order to run it we need to declare the same “glue” interpreters and we’re ready to go:

val storageAndLogInterpreter = keyValueStore.interpreter or consoleLogger.interpreter

val storageAndLogFreeInterpreter = new (StorageAndLogF ~> Id) {
  def apply[A](storageAndLogF: StorageAndLogF[A]): Id[A] = 
    storageAndLogF.foldMap(storageAndLogInterpreter)
}

val inventoryFreeInterpreter = new (InventoryF ~> StorageAndLogF) {
  def apply[A](inventoryF: InventoryF[A]): StorageAndLogF[A] = 
    inventoryInterpreter.run(inventoryF)
    // or inventoryF.foldMap(inventoryInterpreter.interpreter)
}

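// movementInterpreter is not shown in this section; it is assumed to be a
// Stock.Interp[InventoryF] written in the same style as the interpreters above
val stackInterpreter =
  movementInterpreter.interpreter
    .andThen(inventoryFreeInterpreter)
    .andThen(storageAndLogFreeInterpreter)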

// run our program  
inboundStock("po-1").foldMap(stackInterpreter or consoleLogger.interpreter)

This is a nice library that reduces the boilerplate significantly while not feeling any different from the original implementation: all the code looks the same, you compose the interpreters in the exact same way (and you still have to be careful about their combination order).
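Why the order matters can be seen in a small library-free sketch. Here Either plays the role of Coproduct, and storeOrLogger is a hand-written equivalent of what an "or" combinator produces; all the names, as well as the simplified String/Int storage DSL, are assumptions made for the illustration.

```scala
object CoproductOrderSketch {
  type Id[A] = A
  trait ~>[F[_], G[_]] { def apply[A](fa: F[A]): G[A] }

  // Simplified storage DSL (String keys, Int values) and logging DSL.
  sealed trait Query[A]
  final case class Get(key: String)             extends Query[Option[Int]]
  final case class Put(key: String, value: Int) extends Query[Unit]

  sealed trait Log[A]
  final case class Info(message: String) extends Log[Unit]

  // Stand-in for Coproduct[Query, Log, A]: Query on the left, Log on the right.
  type StorageAndLog[A] = Either[Query[A], Log[A]]

  val state  = scala.collection.mutable.Map.empty[String, Int]
  val logged = scala.collection.mutable.ListBuffer.empty[String]

  val store: Query ~> Id = new (Query ~> Id) {
    def apply[A](fa: Query[A]): Id[A] = fa match {
      case Get(key)        => state.get(key)
      case Put(key, value) => state.update(key, value)
    }
  }

  val logger: Log ~> Id = new (Log ~> Id) {
    def apply[A](fa: Log[A]): Id[A] = fa match {
      case Info(message) =>
        logged += message
        ()
    }
  }

  // Equivalent of "store or logger": the left interpreter handles Left and
  // the right one handles Right. Swapping them would give an interpreter for
  // Either[Log[A], Query[A]], which is a different type from StorageAndLog.
  val storeOrLogger: StorageAndLog ~> Id = new (StorageAndLog ~> Id) {
    def apply[A](fa: StorageAndLog[A]): Id[A] = fa match {
      case Left(query) => store(query)
      case Right(log)  => logger(log)
    }
  }

  // Usage: each instruction is routed according to its side of the Either.
  storeOrLogger.apply[Unit](Left(Put("dock/Mars", 200)))
  storeOrLogger.apply[Unit](Right(Info("stored")))
  val found: Option[Int] = storeOrLogger.apply[Option[Int]](Left(Get("dock/Mars")))
}
```

Because the position of each DSL is baked into the coproduct type, the interpreters have to be combined in that same order.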

Freek

If you like freasy but still want more, you probably need Freek. Freek is a powerful library that relies on shapeless to deal with the types. As a result you gain much more flexibility, as you can combine the interpreters in any order. On the other hand, things look a bit different.

Enough talking, let’s see what it looks like. Freek’s approach is sort of the “opposite” of Freasy’s, as this time you have to define all the DSL data types (i.e. the case classes).

This is what our DSL looks like with Freek:

import cats.free.Free
import cats.{Id, ~>}
import cats.implicits._
import freek._

type LocationId      = String
type ProductId       = String
type PurchaseOrderId = String

sealed trait Logging[A]
object Logging {
  case class Info(message: String) extends Logging[Unit]
  case class Error(message: String) extends Logging[Unit]
}

sealed trait Storage[A]
object Storage {
  case class Get[K, V](key: K) extends Storage[Option[V]]
  case class Put[K, V](key: K, value: V) extends Storage[Unit]
}

sealed trait Inventory[A]
object Inventory {
  case class Increment(locationId: LocationId, productId: ProductId, quantity: Int) extends Inventory[Unit]
  case class Decrement(locationId: LocationId, productId: ProductId, quantity: Int) extends Inventory[Unit]
}

sealed trait Movement[A]
object Movement {
  case class POContent(poId: PurchaseOrderId) extends Movement[Map[ProductId, Int]]
  
  case class CreateStock(productId: ProductId, locationId: LocationId, quantity: Int) 
    extends Movement[Unit]
  
  case class MoveStock(productId: ProductId, source: LocationId, destination: LocationId, quantity: Int) 
    extends Movement[Unit]
}

Nice! Just the ADT with no boilerplate! Now let’s have a look at the interpreters, which should look familiar.

val consoleLogger = new (Logging ~> Id) {
  import Logging._
  override def apply[A](fa: Logging[A]): Id[A] = {
    val timestamp = System.currentTimeMillis()
    fa match {
      case Info(message)  => println(s"[INFO ] [$timestamp] $message")
      case Error(message) => println(s"[ERROR] [$timestamp] $message")
    }
  }
}

val keyValueStore = new (Storage ~> Id) {
  import Storage._
  var store: Map[Any, Any] = Map()
  override def apply[A](fa: Storage[A]): Id[A] = fa match {
    // the store is untyped, so the result has to be cast back to the expected type
    case Get(key) => store.get(key).asInstanceOf[A]
    case Put(key, value) => store += key -> value
  }
}

We’re back with our good old interpreters and the pattern matching in the apply method.

Now let’s define the interpreters for the nested DSLs (remember the layered structure of our program).

// This is the way to combine several DSLs into one coproduct
// The nice thing is that you can combine as many as you need at once
type StorageAndLogPRG  = Storage :|: Logging :|: NilDSL
val  StorageAndLogPRG  = DSL.Make[StorageAndLogPRG]
type StorageAndLogF[A] = Free[StorageAndLogPRG.Cop, A]

val inventoryInterpreter = new (Inventory ~> StorageAndLogF) {
  import Inventory._, Logging._, Storage._
  override def apply[A](fa: Inventory[A]) = fa match {
    case Increment(locationId, productId, quantity) =>
      // we use a for-comprehension to translate our Inventory commands
      // into logs and storage operations but we have to specify a common type
      // for our program using the "strange" .freek[PRG] suffix
      for {
        _        <- Info(s"Add $quantity $productId at $locationId").freek[StorageAndLogPRG]
        maybeOld <- Get[(LocationId, ProductId), Int]((locationId, productId)).freek[StorageAndLogPRG]
        existing = maybeOld getOrElse 0
        _        <- Put((locationId, productId), existing + quantity)
          .upcast[Storage[Unit]]
          .freek[StorageAndLogPRG]
      } yield ()

    case Decrement(locationId, productId, quantity) =>
      for {
        _        <- Info(s"Remove $quantity $productId from $locationId").freek[StorageAndLogPRG]
        maybeOld <- Get[(LocationId, ProductId), Int]((locationId, productId)).freek[StorageAndLogPRG]
        existing = maybeOld getOrElse 0
        _        <- Put((locationId, productId), existing - quantity)
          .upcast[Storage[Unit]]
          .freek[StorageAndLogPRG]
      } yield ()
  }
}

// And same thing for translating from Movement to Inventory
type InventoryPRG  = Inventory :|: NilDSL
val InventoryPRG   = DSL.Make[InventoryPRG]
type InventoryF[A] = Free[InventoryPRG.Cop, A]

val movementInterpreter = new (Movement ~> InventoryF) {
  import Movement._, Inventory._

  override def apply[A](fa: Movement[A]): InventoryF[A] = fa match {
    case POContent(_) =>
      Free.pure(Map("Mars" -> 200, "Milkyway" -> 150, "Galaxy" -> 100))
    case CreateStock(productId, locationId, quantity) =>
      for {
        _ <- Increment(locationId, productId, quantity).freek[InventoryPRG]
      } yield ()
    case MoveStock(productId, source, destination, quantity) =>
      for {
        _ <- Decrement(source, productId, quantity).freek[InventoryPRG]
        _ <- Increment(destination, productId, quantity).freek[InventoryPRG]
      } yield ()
  }
}
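The layering used above (Movement translated into Inventory programs, which are then run against some storage) can also be sketched without any library. This is a simplified illustration under stated assumptions: the mini Free type, the run function and the in-memory stock map are invented for the example, and Inventory is interpreted straight into Id rather than through the storage and logging DSLs.

```scala
object LayeredSketch {
  type Id[A] = A
  trait ~>[F[_], G[_]] { def apply[A](fa: F[A]): G[A] }

  // Tiny free monad standing in for cats.free.Free (not stack-safe).
  sealed trait Free[F[_], A] {
    def flatMap[B](f: A => Free[F, B]): Free[F, B] = FlatMapped[F, A, B](this, f)
    def map[B](f: A => B): Free[F, B] = flatMap(a => Pure[F, B](f(a)))
  }
  final case class Pure[F[_], A](value: A) extends Free[F, A]
  final case class Suspend[F[_], A](instruction: F[A]) extends Free[F, A]
  final case class FlatMapped[F[_], A, B](sub: Free[F, A], next: A => Free[F, B])
      extends Free[F, B]

  // Run a program to a plain value by interpreting each instruction into Id.
  def run[F[_], A](program: Free[F, A])(nt: F ~> Id): A = program match {
    case Pure(value)           => value
    case Suspend(instruction)  => nt(instruction)
    case FlatMapped(sub, next) => run(next(run(sub)(nt)))(nt)
  }

  // Lower layer: inventory updates.
  sealed trait Inventory[A]
  final case class Increment(location: String, product: String, quantity: Int) extends Inventory[Unit]
  final case class Decrement(location: String, product: String, quantity: Int) extends Inventory[Unit]
  type InventoryF[A] = Free[Inventory, A]

  // Upper layer: stock movements.
  sealed trait Movement[A]
  final case class MoveStock(product: String, source: String, destination: String, quantity: Int)
      extends Movement[Unit]

  // Movement is not executed directly: it is translated into an Inventory program.
  val movementToInventory: Movement ~> InventoryF = new (Movement ~> InventoryF) {
    def apply[A](fa: Movement[A]): InventoryF[A] = fa match {
      case MoveStock(product, source, destination, quantity) =>
        for {
          _ <- Suspend[Inventory, Unit](Decrement(source, product, quantity))
          _ <- Suspend[Inventory, Unit](Increment(destination, product, quantity))
        } yield ()
    }
  }

  // Inventory instructions are executed against a mutable map.
  val stock = scala.collection.mutable.Map.empty[(String, String), Int]
  val inventoryToId: Inventory ~> Id = new (Inventory ~> Id) {
    def apply[A](fa: Inventory[A]): Id[A] = fa match {
      case Increment(location, product, quantity) =>
        stock.update((location, product), stock.getOrElse((location, product), 0) + quantity)
      case Decrement(location, product, quantity) =>
        stock.update((location, product), stock.getOrElse((location, product), 0) - quantity)
    }
  }

  // "Glue": run the whole Inventory program produced for each Movement instruction.
  val movementToId: Movement ~> Id = new (Movement ~> Id) {
    def apply[A](fa: Movement[A]): Id[A] = run(movementToInventory(fa))(inventoryToId)
  }

  // Usage: move 50 Mars bars from the dock (initially 200) to the warehouse.
  stock.update(("dock", "Mars"), 200)
  run(Suspend[Movement, Unit](MoveStock("Mars", "dock", "warehouse", 50)))(movementToId)
  // stock now holds 150 at the dock and 50 in the warehouse
}
```

The glue step is the important part: a translation into Free of another DSL only becomes runnable once that inner program is itself folded with the next interpreter down.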

And let’s write our main program by combining the types in the same fashion. I like this way of defining the coproducts (note that it’s not the cats Coproduct but an implementation based on shapeless, which allows for these combinations of many types at once).

type PRG = Movement :|: Logging :|: NilDSL
val PRG  = DSL.Make[PRG]

def inboundStock(poId: PurchaseOrderId): Free[PRG.Cop, Unit] =
  for {
    _       <- Logging.Info(s"Inbounding $poId").freek[PRG]
    content <- Movement.POContent(poId).freek[PRG]
    _       <- content.toList.traverseU {
      case (productId, quantity) =>
        Movement.CreateStock(productId, "dock", quantity).freek[PRG]
    }
    _       <- content.toList.traverseU {
      case (productId, quantity) =>
        Movement.MoveStock(productId, "dock", "warehouse", quantity).freek[PRG]
    }
  } yield ()

I am not a big fan of these .freek[PRG] suffixes, but otherwise it looks really good: very similar to our initial program, with less boilerplate and less hassle.

Now let’s create our “glue” interpreters and run our program:

val inventoryFreeInterpreter = new (InventoryF ~> StorageAndLogF) {
  override def apply[A](fa: InventoryF[A]) = 
    fa.interpret(inventoryInterpreter)
}

val storageAndLogFreeInterpreter = new (StorageAndLogF ~> Id) {
  override def apply[A](fa: StorageAndLogF[A]) = 
    fa.interpret(consoleLogger :&: keyValueStore)
}

val stackInterpreter = movementInterpreter
  .andThen(inventoryFreeInterpreter)
  .andThen(storageAndLogFreeInterpreter)

inboundStock("po-1").interpret(consoleLogger :&: stackInterpreter)

Note that we use interpret instead of foldMap. It does the same thing but interpret works even with interpreters not combined in the exact correct order (notice the console logger comes first here whereas in the PRG type Logging comes last).

Freek also provides very useful support for monad transformers (which it calls onion – as you have to “peel” the monad layers to get to the value) but this post is already long enough so I’ll just stop here.

Hopefully this series gave you a decent understanding of how the free monad works and even provided you with some hints to make your life easier if you decide to use it in your projects.