What exactly is going on when we chain monads? Here is some Scalaz code to demonstrate.
First, we create the monad:
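```scala
import scalaz.Monad
import scalaz.syntax.monad._ // map/flatMap syntax on MonadX, needed by the for-comprehension below

// A computation that, given a Context, produces an A
sealed trait MonadX[+A] {
  def run(ctx: Context): A
}

object MonadX {
  def apply[A](f: Context => A): MonadX[A] = new MonadX[A] {
    override def run(ctx: Context): A = f(ctx)
  }

  implicit val monad: Monad[MonadX] = new Monad[MonadX] {
    // Run fa, feed its result to f, then run the MonadX that f returns, all with the same ctx
    override def bind[A, B](fa: MonadX[A])(f: A => MonadX[B]): MonadX[B] =
      MonadX(ctx => f(fa.run(ctx)).run(ctx))

    // Ignore the context and just return a
    override def point[A](a: => A): MonadX[A] = MonadX(_ => a)
  }
}
```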
Next, we create the two monads that we will chain, along with the Context they read from:
```scala
case class Context(aString: String, aLong: Long)

val hello: MonadX[String] = MonadX { ctx: Context =>
  ctx.aString
}

val meaningOfLife: MonadX[Long] = MonadX { ctx: Context =>
  ctx.aLong
}
```
Not very useful, are they? But you get the idea. Now, all we want is a for-comprehension, so:
```scala
val concatLength: MonadX[Int] = for {
  x <- hello
  y <- meaningOfLife
} yield (x + y).length // "hello" + 42 is "hello42", so the length is 7

val ctx = Context("hello", 42)
val length: Int = concatLength.run(ctx)
```
With some logging added, building the for-comprehension prints:

```text
About to run for-comprehension
==============================
bind: Creating boundHello with fa=Hello, f=<function1>
```
This is just the first part of our for-comprehension (x <- hello). Nothing further is executed, because MonadX is lazy: it only wraps a function of the Context. All the bind operation did was create a new MonadX containing a function; that function has not been applied yet.
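To see this laziness in isolation, here is a minimal sketch that drops Scalaz and defines flatMap and map directly on a simplified MonadX (a case class wrapping Context => A). This simplified MonadX and the calls counter are illustrative assumptions, not the post's code. Building concatLength should execute neither wrapped function; only run should.

```scala
// Simplified, Scalaz-free stand-in for the post's MonadX (an assumption for illustration)
final case class Context(aString: String, aLong: Long)

final case class MonadX[A](run: Context => A) {
  // bind/flatMap only builds a new function; nothing is evaluated here
  def flatMap[B](f: A => MonadX[B]): MonadX[B] = MonadX(ctx => f(run(ctx)).run(ctx))
  // map via flatMap + point, mirroring Scalaz's Monad.map
  def map[B](f: A => B): MonadX[B] = flatMap(a => MonadX.point(f(a)))
}

object MonadX {
  def point[A](a: => A): MonadX[A] = MonadX(_ => a)
}

object LazyBindDemo {
  var calls = 0 // hypothetical counter: how often a wrapped function actually executes

  val hello: MonadX[String] = MonadX { ctx => calls += 1; ctx.aString }
  val meaningOfLife: MonadX[Long] = MonadX { ctx => calls += 1; ctx.aLong }

  val concatLength: MonadX[Int] = for {
    x <- hello
    y <- meaningOfLife
  } yield (x + y).length

  val callsBeforeRun: Int = calls // building concatLength ran nothing
  val result: Int = concatLength.run(Context("hello", 42))
  val callsAfterRun: Int = calls // each wrapped function ran exactly once
}
```

Reading LazyBindDemo.callsBeforeRun should give 0, while callsAfterRun should give 2 and result 7.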
Only when we run the outer monad (concatLength.run(ctx)) does the 'program' execute:
```text
About to run boundHello
=======================
boundHello.run
boundHello.f(ctx) =
Hello.run
helloFn(ctx) =
'hello'
Hello.run Finished
<function1>(hello) =
bind: Creating boundMeaningOfLife with fa=MeaningOfLife, f=<function1>
'boundMeaningOfLife'
boundMeaningOfLife.run
boundMeaningOfLife.f(ctx) =
MeaningOfLife.run
meaningOfLifeFn(ctx) =
'42'
MeaningOfLife.run Finished
<function1>(42) =
Creating point (7) [Integer]
'point'
point.run
point.f(ctx) =
'7'
point.run Finished
'7'
boundMeaningOfLife.run Finished
'7'
boundHello.run Finished
```
So, what happened? First, bind was called on our monads hello and meaningOfLife (bind is another name for flatMap in some languages). The reason is that a for-comprehension is de-sugared into a chain of flatMap calls, with a final map for the last generator. Yes, the de-sugared for-comprehension invokes map, but a map can be replaced by a flatMap followed by a point (sometimes called unit; see the monad laws). This is where the point in the trace above comes from.
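The substitution can be written out by hand. This sketch again uses a simplified, Scalaz-free MonadX (an assumption for illustration), this time with map implemented directly, so the three forms below take genuinely different code paths: the for-comprehension as written, the flatMap/map chain the compiler produces from it, and the same chain with the final map replaced by flatMap plus point. All three should produce the same value.

```scala
final case class Context(aString: String, aLong: Long)

// Simplified, Scalaz-free stand-in for the post's MonadX (an assumption for illustration)
final case class MonadX[A](run: Context => A) {
  def flatMap[B](f: A => MonadX[B]): MonadX[B] = MonadX(ctx => f(run(ctx)).run(ctx))
  def map[B](f: A => B): MonadX[B] = MonadX(ctx => f(run(ctx))) // direct map, no point involved
}

object MonadX {
  def point[A](a: => A): MonadX[A] = MonadX(_ => a)
}

object Desugaring {
  val hello: MonadX[String] = MonadX(_.aString)
  val meaningOfLife: MonadX[Long] = MonadX(_.aLong)
  val ctx: Context = Context("hello", 42)

  // 1. What we write
  val sugared: MonadX[Int] = for { x <- hello; y <- meaningOfLife } yield (x + y).length

  // 2. What the compiler generates: flatMap for the first generator, map for the last
  val desugared: MonadX[Int] = hello.flatMap(x => meaningOfLife.map(y => (x + y).length))

  // 3. The final map replaced by flatMap + point
  val viaPoint: MonadX[Int] =
    hello.flatMap(x => meaningOfLife.flatMap(y => MonadX.point((x + y).length)))

  val results: List[Int] = List(sugared, desugared, viaPoint).map(_.run(ctx))
}
```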
Leveraging this substitution, Scalaz defines Monad.map as:

```scala
map[A,B](fa: F[A])(f: A => B): F[B] = bind(fa)(a => point(f(a)))
```
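The same derivation can be expressed as a tiny type class. SimpleMonad below is a hypothetical stand-in for Scalaz's Monad, reduced to point, bind and a map derived from them exactly as in the definition above; the MonadXMonad instance mirrors the post's bind and point on a simplified, case-class MonadX (also an assumption).

```scala
// Hypothetical minimal type class (not Scalaz's): map is derived, never implemented directly
trait SimpleMonad[F[_]] {
  def point[A](a: => A): F[A]
  def bind[A, B](fa: F[A])(f: A => F[B]): F[B]
  // The same substitution as Monad.map: bind, then wrap the plain result with point
  def map[A, B](fa: F[A])(f: A => B): F[B] = bind(fa)(a => point(f(a)))
}

final case class Context(aString: String, aLong: Long)
final case class MonadX[A](run: Context => A)

// Instance mirroring the post's bind and point
object MonadXMonad extends SimpleMonad[MonadX] {
  def point[A](a: => A): MonadX[A] = MonadX(_ => a)
  def bind[A, B](fa: MonadX[A])(f: A => MonadX[B]): MonadX[B] =
    MonadX(ctx => f(fa.run(ctx)).run(ctx))
}

object DerivedMap {
  val hello: MonadX[String] = MonadX(_.aString)
  val meaningOfLife: MonadX[Long] = MonadX(_.aLong)
  val ctx: Context = Context("hello", 42)

  // map works although only bind and point were written for MonadX
  val helloLength: MonadX[Int] = MonadXMonad.map(hello)(_.length)

  // The post's for-comprehension, spelled out with bind and the derived map
  val concatLength: MonadX[Int] =
    MonadXMonad.bind(hello)(x => MonadXMonad.map(meaningOfLife)(y => (x + y).length))
}
```

Running DerivedMap.helloLength with ctx should give 5, and DerivedMap.concatLength should give 7, as in the trace.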
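In my heavily de-sugared version of this code, I have given my functions names, but the Scala compiler also generates anonymous functions (shown as <function1>). The first appears to be the rest of the comprehension starting at y <- meaningOfLife (that is, x => meaningOfLife.map(y => (x + y).length)); the second appears to be the yield function, (x + y).length, as wrapped by Monad.map into a => point(f(a)).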
So, in a brief, hand-wavy summary:
1. The outermost monad is not hello but boundHello, which wraps it.
2. boundHello calls run on its hello.
3. It feeds the result into its function, f. This happens to be the block of code resulting from y <- meaningOfLife. Since map is implemented with the bind/point substitution, calling f gives us a boundMeaningOfLife.
4. boundHello runs this boundMeaningOfLife, which, being the same recursive structure, repeats steps 1 and 2 on its wrapped MeaningOfLife monad.
5. Like boundHello in step 3, boundMeaningOfLife calls its function f, but this time the result is a point.
6. Again, because the structure is recursive, run is called on point, which returns the result of the yield function.
7. Finally, the program's "stack" unwinds all the way to the top, returning our result.
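The steps above can be reproduced with a small instrumented sketch. This MonadX is hypothetical: each value carries a name, flatMap names its result "bound" plus the wrapped monad's name, map goes through flatMap and point as Monad.map does, and run appends entry and exit messages to a log. The recorded order should match the nesting of the trace above.

```scala
import scala.collection.mutable.ListBuffer

final case class Context(aString: String, aLong: Long)

// Hypothetical global log recording run entry/exit in execution order
object Trace {
  val log: ListBuffer[String] = ListBuffer.empty[String]
}

// Hypothetical instrumented MonadX: every value carries a name for the log
final case class MonadX[A](name: String, body: Context => A) {
  def run(ctx: Context): A = {
    Trace.log += s"$name.run"
    val a = body(ctx)
    Trace.log += s"$name.run Finished"
    a
  }
  // bind: the result wraps this monad and is named "bound" + this monad's name
  def flatMap[B](f: A => MonadX[B]): MonadX[B] =
    MonadX(s"bound$name", ctx => f(run(ctx)).run(ctx))
  // map via bind + point, the same substitution as Monad.map
  def map[B](f: A => B): MonadX[B] = flatMap(a => MonadX.point(f(a)))
}

object MonadX {
  def point[A](a: => A): MonadX[A] = MonadX("point", _ => a)
}

object TraceDemo {
  val hello: MonadX[String] = MonadX("Hello", _.aString)
  val meaningOfLife: MonadX[Long] = MonadX("MeaningOfLife", _.aLong)

  val concatLength: MonadX[Int] = for {
    x <- hello
    y <- meaningOfLife
  } yield (x + y).length

  val result: Int = concatLength.run(Context("hello", 42))
}
```

After TraceDemo.result is evaluated, Trace.log holds boundHello.run, then Hello.run and its Finished entry, then boundMeaningOfLife.run wrapping MeaningOfLife.run and point.run, and finally the Finished entries unwinding back to boundHello.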