Saturday, October 22, 2016

State monads

State is generally anathema to functional programmers but necessary to the rest of us. So, how are these forces reconciled? Perhaps not surprisingly, with monads.

"We'll say that a stateful computation is a function that takes some state and returns a value along with some new state." (from LYAHFGG). So, if you had a random number generator (RNG), each call would return not just a random number but also a generator carrying the new state for the next call. Calling the same generator repeatedly gives you the same value, which makes testing easier.

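As a rough illustration of that idea (not taken from LYAHFGG or Scalaz), here is a minimal sketch of a pure RNG. The names Rng and RngDemo are made up for this example, and the linear congruential constants are just one common choice, assumed here for the sake of having something concrete.

```scala
// Hypothetical pure RNG: a simple linear congruential generator.
final case class Rng(seed: Long) {
  // Returns the generated number together with the generator to use next.
  def nextInt: (Int, Rng) = {
    val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
    ((newSeed >>> 16).toInt, Rng(newSeed))
  }
}

object RngDemo {
  def main(args: Array[String]): Unit = {
    val g0 = Rng(42L)
    val (a, g1) = g0.nextInt
    val (b, _)  = g0.nextInt // same generator, same state, so b == a
    val (c, _)  = g1.nextInt // the returned generator carries the new state
    println(s"a = $a, b = $b, c = $c")
  }
}
```

Nothing is mutated: the only way to get a different number is to use the generator that came back from the previous call.
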
Scalaz gives you a state out of the box. As an example, let's say we're modelling a stack data structure. The pop would look like this:

import scalaz._
import Scalaz._

    type Stack = List[Int]

    val pop = State[Stack, Int] { // (f: S => (S, A)) where S is Stack and A is Int
      case x::xs =>
        (xs, x)
    }

Note that pop is a val not a def. Like the RNG mentioned earlier, it returns the popped value and the new state. We'll ignore error conditions like popping an empty stack for now.

Also, "the important thing to note is that unlike the general monads we've seen, State specifically wraps functions." (Learning Scalaz)

Let's now define what a push is:

    def push(a: Int) = State[Stack, Unit] {
      case xs =>
        println(s"Pushing $a")
        (a :: xs, ())
    }

Once more, the function returns the new state and a value - only this time the value is Unit, which carries no information. What value can a push to a stack return that you didn't have in the first place?

Because State is a monad, "the powerful part is the fact that we can monadically chain each operations using for syntax without manually passing around the Stack values". That is, we access the contents of State via map and flatMap.

So, let's say we want to pop the value at the top of the stack then push two values onto it, 3 and 4.

    val popx1Push2: State[Stack, Int] = for {
      a <- pop
      _ <- push(3)
      _ <- push(4)
    } yield a 

which does absolutely nothing by itself. What is to be popped? Onto what do we push 3 and 4? We need to run it and then map (or foreach or whatever) over the result to access the values. "Essentially, the reader monad lets us pretend the value is already there." (from Learning Scalaz)

    popx1Push2.run(List(0, 1, 2)).foreach { case (state, popped) =>
      println(s"New state = $state, popped value = $popped") // New state = List(4, 3, 1, 2), popped value = 0
    }
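For comparison, here is a rough sketch of the same pop-then-push-twice sequence written with plain functions and no monad at all, so you can see the bookkeeping the for syntax hides. The object name ManualThreading and the intermediate names s0 to s3 are made up for this example, and the empty-stack case simply throws, which is my assumption rather than anything Scalaz does.

```scala
object ManualThreading {
  type Stack = List[Int]

  // Plain functions: each takes the current stack and returns (new stack, value).
  def pop(s: Stack): (Stack, Int) = s match {
    case x :: xs => (xs, x)
    case Nil     => sys.error("cannot pop an empty stack")
  }

  def push(a: Int)(s: Stack): (Stack, Unit) = (a :: s, ())

  // Every intermediate stack has to be named and passed on by hand.
  def popx1Push2(s0: Stack): (Stack, Int) = {
    val (s1, a) = pop(s0)
    val (s2, _) = push(3)(s1)
    val (s3, _) = push(4)(s2)
    (s3, a)
  }

  def main(args: Array[String]): Unit =
    println(popx1Push2(List(0, 1, 2))) // (List(4, 3, 1, 2),0)
}
```

Passing s1 where s2 was meant would still compile, which is exactly the kind of slip the monadic version rules out.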

But we don't need Scalaz. We can roll our own state (with code liberally stolen from here).

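case class State [S, +A](runS: S => (A, S)) {

  // Run this computation, then transform the accompanying value; the state passes through.
  def map[B](f: A => B) =
    State [S, B]( s => {
      val (a, s1) = runS(s)
      (f(a), s1)
    })

  // Run this computation, then feed its value to f and run the result on the new state.
  def flatMap[B](f: A => State [S, B]) =
    State [S, B]( s => {
      val (a, s1) = runS(s)
      f(a).runS(s1)
    })
}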

Now it becomes clearer what State actually represents. It is a monad (see map and flatMap) that contains a function that given one state, S, can give you the next plus an accompanying value, A.
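To make that concrete, here is a small sketch that ports the earlier stack example onto a hand-rolled State of this shape. The definition is repeated inside a made-up object, HandRolledStack, so the snippet stands alone. Note that the tuple order is (value, state) here, whereas the Scalaz version above uses (state, value). The empty-stack case that throws is my own assumption.

```scala
object HandRolledStack {
  case class State[S, +A](runS: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State[S, B](s => { val (a, s1) = runS(s); (f(a), s1) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
  }

  type Stack = List[Int]

  // (value, state): the popped element first, the remaining stack second.
  val pop: State[Stack, Int] = State[Stack, Int] {
    case x :: xs => (x, xs)
    case Nil     => sys.error("cannot pop an empty stack")
  }

  def push(a: Int): State[Stack, Unit] =
    State[Stack, Unit](xs => ((), a :: xs))

  val popx1Push2: State[Stack, Int] = for {
    a <- pop
    _ <- push(3)
    _ <- push(4)
  } yield a

  def main(args: Array[String]): Unit =
    println(popx1Push2.runS(List(0, 1, 2))) // (0,List(4, 3, 1, 2))
}
```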

So, manipulating the state for our stack example looks like this:

  def getState[S]: State [S, S] =
    State (s => (s, s))

where the function says: given a state, return it as both the accompanying value and the next state.

  def setState[S](s: S): State [S, Unit] =
    State (_ => ((), s))

where the function says: I don't care which state you give me, the next one is s. Oh, and I don't care about the accompanying value.

  def pureState[S, A](a: A): State[S, A] =
    State(s => (a , s))

where the function says the next state is the last state and the accompanying value is what I was given. Pure in this sense refers to the construction of a state, sometimes called unit, and called point in Scalaz.

Now, we do a stateful computation. It's very simple: it chains the state monads in a for comprehension to build a unit of work that adds 1 to whatever state it is given.

    val add1: State[Int, Unit] = for {
      n   <- getState
      b   <- setState(n + 1)
    } yield (b)

    println(add1.runS(7)) // "((),8)" - the accompanying value is Unit and the new state is 8
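The read-then-write pattern in add1 generalises nicely. Below is a sketch of a helper, which I've called modifyState (a hypothetical name, not from the code this post borrows), built from nothing but getState and setState. The enclosing object ModifyDemo and the add3 value are likewise invented for the example, and the State definition is repeated so the snippet compiles without the rest of the post.

```scala
object ModifyDemo {
  case class State[S, +A](runS: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State[S, B](s => { val (a, s1) = runS(s); (f(a), s1) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
  }

  def getState[S]: State[S, S] = State[S, S](s => (s, s))
  def setState[S](s: S): State[S, Unit] = State[S, Unit](_ => ((), s))

  // Hypothetical helper: read the state, apply f, write the result back.
  def modifyState[S](f: S => S): State[S, Unit] =
    getState[S].flatMap(s => setState[S](f(s)))

  val add1: State[Int, Unit] = modifyState[Int](_ + 1)

  // Chaining three add1 steps; no state is passed around explicitly.
  val add3: State[Int, Unit] = for {
    _ <- add1
    _ <- add1
    _ <- add1
  } yield ()

  def main(args: Array[String]): Unit =
    println(add3.runS(7)) // ((),10)
}
```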

This is trivial but you can do more complicated operations like:

  def zipWithIndex[A](as: List [A]): List [(Int , A)] =
    as.foldLeft (
      pureState[Int, List[(Int , A)]](List())
    )((acc, a) => for {
      xs  <- acc
      n   <- getState
      _   <- setState(n + 1)
    } yield (n , a) :: xs ).runS(0)._1.reverse

Yoiks! What's this doing? Well, reading it de-sugared and pretending we've given it a list of Chars (called as) may make it easier:

    type IndexedChars = List[(Int, Char)]

    val accumulated: State[Int, IndexedChars] = as.foldLeft(
      pureState[Int, IndexedChars](List()) // "the next state is the last state and the accompanying value is an empty list"
    )((acc: State[Int, IndexedChars], a: Char) => {
      acc.flatMap { xs: IndexedChars =>   // xs is the accompanying value that the accumulator wraps
        getState.flatMap { n: Int =>      // "Given a state, return that state." (n is fed in via runS)
          setState(n + 1).map(ignored =>  // "I don't care which state you give me, the next one is n+1"
            ((n, a) +: xs)                // the accompanying value is now the old plus the (n,a) tuple.
          )
        }
      }
    })

Which is a bit more understandable.
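To check that reading, here is a self-contained sketch that runs zipWithIndex on a concrete list. It repeats the definitions from above inside a made-up object, ZipDemo, and spells out getState[Int] explicitly so the snippet does not lean on type inference.

```scala
object ZipDemo {
  case class State[S, +A](runS: S => (A, S)) {
    def map[B](f: A => B): State[S, B] =
      State[S, B](s => { val (a, s1) = runS(s); (f(a), s1) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
  }

  def getState[S]: State[S, S] = State[S, S](s => (s, s))
  def setState[S](s: S): State[S, Unit] = State[S, Unit](_ => ((), s))
  def pureState[S, A](a: A): State[S, A] = State[S, A](s => (a, s))

  def zipWithIndex[A](as: List[A]): List[(Int, A)] =
    as.foldLeft(
      pureState[Int, List[(Int, A)]](List())
    )((acc, a) => for {
      xs <- acc              // the pairs built so far
      n  <- getState[Int]    // the current index
      _  <- setState(n + 1)  // the index for the next element
    } yield (n, a) :: xs).runS(0)._1.reverse

  def main(args: Array[String]): Unit =
    println(zipWithIndex(List('a', 'b', 'c'))) // List((0,a), (1,b), (2,c))
}
```

The fold itself only builds up one big State; no counting happens until runS(0) supplies the starting index.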
