Tuesday, January 22, 2019

Chaining Monads


What exactly is going on when we chain monads? Here is some Scalaz code to demonstrate.

First, we create the monad:

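import scalaz.Monad

// A Reader-like monad: a program that needs a Context to produce an A
sealed trait MonadX[+A] {
  def run(ctx: Context): A
}

object MonadX {

  def apply[A](f: Context => A): MonadX[A] = new MonadX[A] {
    override def run(ctx: Context): A = f(ctx)
  }

  // Implicit definitions should carry an explicit type
  implicit val monad: Monad[MonadX] = new Monad[MonadX] {
    // Run fa, feed its result to f, then run the resulting program with the same ctx
    override def bind[A, B](fa: MonadX[A])(f: A => MonadX[B]): MonadX[B] =
      MonadX(ctx => f(fa.run(ctx)).run(ctx))

    // Lift a plain value into MonadX; the context is ignored
    override def point[A](a: => A): MonadX[A] = MonadX(_ => a)
  }

}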

Next, we create two such monads to chain, plus the Context they read from:

    case class Context(aString: String, aLong: Long)

    val hello: MonadX[String] = MonadX { ctx: Context =>
      ctx.aString
    }
    val meaningOfLife: MonadX[Long] = MonadX { ctx: Context =>
      ctx.aLong
    }

Not very useful, are they? But you get the idea. Now, all we want is a for-comprehension, so:

  import scalaz.syntax.monad._ // supplies flatMap and map for MonadX

  val concatLength: MonadX[Int] = for {
    x <- hello
    y <- meaningOfLife
  } yield (x + y).length

You can think of monads as programs, so let's run it:

  val ctx = Context("hello", 42)

  val length: Int = concatLength.run(ctx) // "hello" + 42 is "hello42", so 7
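To see what the compiler does with the for-comprehension, here is a minimal sketch that avoids Scalaz entirely. It assumes a stand-in MonadX written as a case class with flatMap and map defined directly on it (the post's MonadX gets them from Scalaz instead), and the object name DesugarDemo is made up for this example. It compares the for-comprehension with its hand-written de-sugaring.

```scala
// Stdlib-only stand-in for the post's MonadX: flatMap and map live on the type
// itself here, whereas the post's version gets them from Scalaz syntax.
object DesugarDemo {
  final case class Context(aString: String, aLong: Long)

  final case class MonadX[A](run: Context => A) {
    // Run this program, pass its result to f, then run f's program with the same ctx
    def flatMap[B](f: A => MonadX[B]): MonadX[B] =
      MonadX(ctx => f(run(ctx)).run(ctx))

    // Transform the result without touching the context
    def map[B](f: A => B): MonadX[B] =
      MonadX(ctx => f(run(ctx)))
  }

  val hello: MonadX[String] = MonadX(_.aString)
  val meaningOfLife: MonadX[Long] = MonadX(_.aLong)

  // The for-comprehension from the post...
  val concatLength: MonadX[Int] = for {
    x <- hello
    y <- meaningOfLife
  } yield (x + y).length

  // ...and the nested flatMap/map calls it is rewritten into
  val desugared: MonadX[Int] =
    hello.flatMap(x => meaningOfLife.map(y => (x + y).length))
}
```

Both programs return 7 for Context("hello", 42), because "hello" + 42L is "hello42".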

Using a heavily de-sugared, instrumented (and decidedly non-FP) version of this code, the flow of control looks like this:

About to run for-comprehension
==============================
bind: Creating boundHello with fa=Hello, f=<function1>

This is just the first part of our for-comprehension (x <- hello). Note that nothing further is executed: MonadX only wraps a function of Context, so building one runs nothing. All the bind operation did was create a new MonadX containing a function; we never applied that function.
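A small way to check this laziness, sketched without Scalaz: the stand-in MonadX below and the helloRuns counter are hypothetical instrumentation, not part of the post's code. Binding hello leaves the counter at 0; only running the bound program increments it.

```scala
// Stdlib-only stand-in for MonadX, plus a hypothetical counter that records
// how many times hello's body actually executes.
object LazinessDemo {
  final case class Context(aString: String, aLong: Long)

  final case class MonadX[A](run: Context => A) {
    def flatMap[B](f: A => MonadX[B]): MonadX[B] =
      MonadX(ctx => f(run(ctx)).run(ctx))
  }

  var helloRuns = 0 // instrumentation only

  val hello: MonadX[String] = MonadX { ctx =>
    helloRuns += 1
    ctx.aString
  }

  // Binding just wraps a new function around hello; nothing is executed yet
  val boundHello: MonadX[Int] = hello.flatMap(s => MonadX(_ => s.length))
  val runsAfterBind: Int = helloRuns

  // Running the outer program is what finally executes hello's body
  val result: Int = boundHello.run(Context("hello", 42L))
  val runsAfterRun: Int = helloRuns
}
```

Here runsAfterBind is 0, runsAfterRun is 1 and result is 5 (the length of "hello").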

Only when we run the outer monad (concatLength.run(ctx)) does the 'program' execute:

About to run boundHello
=======================
boundHello.run
    boundHello.f(ctx) = 
        Hello.run
            helloFn(ctx) = 
                'hello'
        Hello.run Finished
        <function1>(hello) = 
            bind: Creating boundMeaningOfLife with fa=MeaningOfLife, f=<function1>
            'boundMeaningOfLife'
        boundMeaningOfLife.run
            boundMeaningOfLife.f(ctx) = 
                MeaningOfLife.run
                    meaningOfLifeFn(ctx) = 
                        '42'
                MeaningOfLife.run Finished
                <function1>(42) = 
                    Creating point (7) [Integer]
                    'point'
                point.run
                    point.f(ctx) = 
                        '7'
                point.run Finished
                '7'
        boundMeaningOfLife.run Finished
        '7'
boundHello.run Finished
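The same ordering can be reproduced with a stdlib-only sketch that appends to a log instead of printing. Here bind, point and map are plain functions standing in for the Scalaz instance, map is written as bind followed by point, and the log messages are invented for illustration; they only approximate the trace above.

```scala
import scala.collection.mutable.ListBuffer

// Stdlib-only sketch: bind, point and map stand in for the Scalaz instance, and
// the log messages are invented to approximate the trace above.
object TraceDemo {
  final case class Context(aString: String, aLong: Long)
  final case class MonadX[A](name: String, run: Context => A)

  val log: ListBuffer[String] = ListBuffer.empty[String]

  def point[A](a: => A): MonadX[A] = {
    log += "create point"
    MonadX("point", _ => a)
  }

  def bind[A, B](fa: MonadX[A])(f: A => MonadX[B]): MonadX[B] = {
    log += s"create bound${fa.name}"
    MonadX(s"bound${fa.name}", ctx => {
      val a = fa.run(ctx)
      log += s"${fa.name} returned $a"
      f(a).run(ctx)
    })
  }

  // map as bind followed by point, the same substitution Scalaz's Monad.map uses
  def map[A, B](fa: MonadX[A])(f: A => B): MonadX[B] =
    bind(fa)(a => point(f(a)))

  val hello: MonadX[String] = MonadX("Hello", _.aString)
  val meaningOfLife: MonadX[Long] = MonadX("MeaningOfLife", _.aLong)

  // De-sugared form of the post's for-comprehension
  val concatLength: MonadX[Int] =
    bind(hello)(x => map(meaningOfLife)(y => (x + y).length))

  log += "about to run"
  val length: Int = concatLength.run(Context("hello", 42L))
}
```

The log records boundHello being created before anything runs, while boundMeaningOfLife and point are only created once the program is already running.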

So, what happened? First, bind was called on both of our monads, hello and meaningOfLife (bind is another name for flatMap, which is what Scala calls it). The reason is that a for-comprehension is de-sugared into flatMap calls under the covers. Strictly speaking, the final generator is de-sugared into map, but a map can be expressed as a flatMap plus a point (sometimes called unit; see the monad laws), and this is where the point in the above flow comes from.

Leveraging this substitution, Scalaz's Monad.map is defined as:

override def map[A, B](fa: F[A])(f: A => B): F[B] = bind(fa)(a => point(f(a)))
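A quick check of the substitution, again as a Scalaz-free sketch with hypothetical names: mapViaBind follows the shape of Monad.map above, while mapDirect applies f to the result directly. Run on the same context, they agree.

```scala
// Stdlib-only stand-in for MonadX with bind and point as plain functions.
// mapViaBind mirrors the shape of Monad.map; mapDirect is a hand-written map.
object MapViaBindDemo {
  final case class Context(aString: String, aLong: Long)
  final case class MonadX[A](run: Context => A)

  def point[A](a: => A): MonadX[A] = MonadX(_ => a)

  def bind[A, B](fa: MonadX[A])(f: A => MonadX[B]): MonadX[B] =
    MonadX(ctx => f(fa.run(ctx)).run(ctx))

  // map derived from bind and point
  def mapViaBind[A, B](fa: MonadX[A])(f: A => B): MonadX[B] =
    bind(fa)(a => point(f(a)))

  // map written directly: apply f to the result of running fa
  def mapDirect[A, B](fa: MonadX[A])(f: A => B): MonadX[B] =
    MonadX(ctx => f(fa.run(ctx)))

  val meaningOfLife: MonadX[Long] = MonadX(_.aLong)
  val ctx: Context = Context("hello", 42L)

  val viaBind: Long = mapViaBind(meaningOfLife)(_ * 2).run(ctx)
  val direct: Long = mapDirect(meaningOfLife)(_ * 2).run(ctx)
}
```

Both yield 84 when doubling meaningOfLife on Context("hello", 42).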

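MonadX defines no map or flatMap methods, yet the Scala compiler needs them to de-sugar the for-comprehension. So where do they come from? Scalaz provides them as syntax: scalaz.syntax.FunctorOps.map and scalaz.syntax.BindOps.flatMap, brought into scope by an import such as scalaz.syntax.monad._.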
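To make the mechanism concrete, here is a dependency-free imitation: MonadX has no flatMap or map of its own, and an implicit class adds them whenever a Monad instance for the type can be found. The Monad trait and MonadOps class are hypothetical stand-ins playing the roles of Scalaz's typeclass and its BindOps/FunctorOps syntax; they are not Scalaz's actual definitions.

```scala
// Dependency-free imitation of how Scalaz supplies flatMap/map. Monad, MonadOps
// and the rest are hypothetical stand-ins, not Scalaz's own definitions.
object SyntaxDemo {
  // Plays the role of scalaz.Monad: only point and bind are abstract
  trait Monad[F[_]] {
    def point[A](a: => A): F[A]
    def bind[A, B](fa: F[A])(f: A => F[B]): F[B]
    def map[A, B](fa: F[A])(f: A => B): F[B] = bind(fa)(a => point(f(a)))
  }

  final case class Context(aString: String, aLong: Long)

  // Like the post's MonadX, this type has no flatMap or map methods of its own
  final case class MonadX[A](run: Context => A)

  object MonadX {
    implicit val monad: Monad[MonadX] = new Monad[MonadX] {
      def point[A](a: => A): MonadX[A] = MonadX(_ => a)
      def bind[A, B](fa: MonadX[A])(f: A => MonadX[B]): MonadX[B] =
        MonadX(ctx => f(fa.run(ctx)).run(ctx))
    }
  }

  // Plays the role of BindOps/FunctorOps: adds flatMap and map to any F[A]
  // for which a Monad[F] instance can be found
  implicit class MonadOps[F[_], A](fa: F[A])(implicit M: Monad[F]) {
    def flatMap[B](f: A => F[B]): F[B] = M.bind(fa)(f)
    def map[B](f: A => B): F[B] = M.map(fa)(f)
  }

  val hello: MonadX[String] = MonadX(_.aString)
  val meaningOfLife: MonadX[Long] = MonadX(_.aLong)

  // Compiles only because MonadOps supplies flatMap and map
  val concatLength: MonadX[Int] = for {
    x <- hello
    y <- meaningOfLife
  } yield (x + y).length
}
```

Scalaz organizes its syntax differently, but the idea is the same: the compiler wraps hello in an ops class, and that class delegates flatMap to the instance's bind.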

In my heavily de-sugared version of this code, I have given my functions names, but the Scala compiler also passes in anonymous functions (<function1>). These appear to be y <- meaningOfLife (the rest of the for-comprehension after x <- hello) for the first one, and the yield function, (x + y).length, for the second, wrapped by map as a => point(f(a)).

So, in a brief, hand-wavy summary:
  1. The outermost monad is not hello but boundHello, which wraps it.
  2. boundHello calls run on its hello.
  3. It feeds the result from this into its function, f. This happens to be the block of code that results from y <- meaningOfLife. Since map is implemented via the bind/point substitution, calling it gives us a boundMeaningOfLife.
  4. boundHello runs this boundMeaningOfLife which, being the same recursive structure, goes through the same steps as 1 and 2 but on its wrapped MeaningOfLife monad.
  5. Again, like boundHello in step #3, boundMeaningOfLife calls its f function, but this time the result is a point.
  6. Again, since it's a recursive structure, run is called on point, which returns the result of the yield function.
  7. Then this program's "stack" is popped all the way to the top with our result.
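The summary can be mirrored by building the chain by hand, with the intermediate programs named after the trace. This is a sketch without Scalaz: bind and point are local stand-ins for the MonadX instance, and the step comments map only loosely onto the numbered list.

```scala
// Stdlib-only sketch: bind and point stand in for the MonadX instance, and the
// intermediate programs are named after the trace above.
object SummaryDemo {
  final case class Context(aString: String, aLong: Long)
  final case class MonadX[A](run: Context => A)

  def point[A](a: => A): MonadX[A] = MonadX(_ => a)
  def bind[A, B](fa: MonadX[A])(f: A => MonadX[B]): MonadX[B] =
    MonadX(ctx => f(fa.run(ctx)).run(ctx))

  val hello: MonadX[String] = MonadX(_.aString)
  val meaningOfLife: MonadX[Long] = MonadX(_.aLong)

  // Steps 1 to 3: the outermost program wraps hello, and its f is called with
  // hello's result to build the next bound program
  val boundHello: MonadX[Int] = bind(hello) { x =>
    // Steps 4 to 6: boundMeaningOfLife wraps meaningOfLife, and its f returns a point
    val boundMeaningOfLife: MonadX[Int] =
      bind(meaningOfLife)(y => point((x + y).length))
    boundMeaningOfLife
  }

  // Step 7: running the outermost program unwinds back to the yielded value
  val length: Int = boundHello.run(Context("hello", 42L))
}
```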

