Mutable state is generally an anathema to functional programmers, but the rest of us can't do without it. So how are these forces reconciled? Perhaps not surprisingly, with monads.
"We'll say that a stateful computation is a function that takes some state and returns a value along with some new state." (from Learn You a Haskell for Great Good!, LYAHFGG). So if you had a random number generator (RNG), each call would return not just a random number but also a generator with a new state for the next call. Calling the same generator repeatedly gives you the same value, which makes testing easier.
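To make that concrete, here is a minimal sketch of a purely functional RNG in plain Scala (no Scalaz). SimpleRNG and nextInt are hypothetical names; the linear congruential constants happen to be the ones java.util.Random uses, but this is not java.util.Random's API.

```scala
// Hypothetical immutable RNG: nextInt returns a value plus a new generator
// instead of mutating any internal state.
final case class SimpleRNG(seed: Long) {
  def nextInt: (Int, SimpleRNG) = {
    // One linear congruential step, kept to 48 bits
    val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
    val n = (newSeed >>> 16).toInt
    (n, SimpleRNG(newSeed))
  }
}

val rng = SimpleRNG(42)
val (a, rng2) = rng.nextInt  // a value and the generator for the next call
val (b, _)    = rng.nextInt  // same generator, so the same value again
val (c, _)    = rng2.nextInt // the new generator moves the sequence on
println(s"$a $b $c")
```

Because the generator is just an immutable value, a test can pin the seed and get the same sequence every time.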
Scalaz gives you a State monad out of the box. As an example, let's say we're modelling a stack as a List[Int]. Popping looks like this:
import scalaz._
import Scalaz._

type Stack = List[Int]

val pop = State[Stack, Int] { // State takes f: S => (S, A), where S is Stack and A is Int
  case x :: xs =>
    println("Popping...")
    (xs, x) // the new state comes first, then the popped value
}
As Learning Scalaz puts it, "the important thing to note is that unlike the general monads we've seen, State specifically wraps functions."
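Since State wraps a plain function, the function itself can be written without Scalaz. One thing worth noticing: pop's pattern only matches a non-empty list, so running it on an empty stack throws a MatchError. A minimal sketch of a total alternative, assuming we are happy to return Option[Int] (safePop is a hypothetical name), could then be handed to State[Stack, Option[Int]] in the same way pop's function is:

```scala
type Stack = List[Int]

// Hypothetical total pop with the S => (S, A) shape Scalaz's State wraps.
// The accompanying value is None on an empty stack, so nothing throws.
val safePop: Stack => (Stack, Option[Int]) = {
  case x :: xs => (xs, Some(x))
  case Nil     => (Nil, None)
}

println(safePop(List(1, 2, 3))) // (List(2, 3),Some(1))
println(safePop(Nil))           // (List(),None)
```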
Let's now define what a push is:
def push(a: Int) = State[Stack, Unit] {
  case xs =>
    println(s"Pushing $a")
    (a :: xs, ()) // new stack, and no interesting value
}
Because State is a monad, "the powerful part is the fact that we can monadically chain each operations using for syntax without manually passing around the Stack values". The for syntax is just sugar for flatMap and map, which is how we get at the contents of a State.
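For contrast, here is a sketch of the same pop-then-push-3-then-push-4 sequence with the stack threaded by hand, using plain functions of the same shape and no Scalaz. popF, pushF and popx1Push2Manual are hypothetical names, and popF assumes a non-empty stack just like pop.

```scala
type Stack = List[Int]

// Plain functions with the shape State wraps: old stack => (new stack, value)
def popF(s: Stack): (Stack, Int) = (s.tail, s.head) // assumes s is non-empty
def pushF(a: Int)(s: Stack): (Stack, Unit) = (a :: s, ())

// Every step must be handed the stack produced by the previous step
def popx1Push2Manual(s0: Stack): (Stack, Int) = {
  val (s1, a) = popF(s0)
  val (s2, _) = pushF(3)(s1)
  val (s3, _) = pushF(4)(s2)
  (s3, a)
}

println(popx1Push2Manual(List(0, 1, 2))) // (List(4, 3, 1, 2),0)
```

The s1, s2, s3 bookkeeping is exactly what the for comprehension over State hides.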
So, let's say we want to pop the value at the top of the stack and then push two values onto it, 3 and 4:
val popx1Push2: State[Stack, Int] = for {
  a <- pop
  _ <- push(3)
  _ <- push(4)
} yield a
On its own, this does absolutely nothing. What is to be popped? Onto what do we push 3 and 4? popx1Push2 is only a description of the computation; to get anything out of it we have to run it with an initial stack and then pull the values out of the result. Learning Scalaz says this of the reader monad, and it holds for State too: "Essentially, the reader monad lets us pretend the value is already there."
// In Scalaz 7, run returns the final state paired with the value
val (state, popped) = popx1Push2.run(List(0, 1, 2))
println(s"New state = $state, popped value = $popped") // New state = List(4, 3, 1, 2), popped value = 0
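To see what is really going on, it helps to write a stripped-down State ourselves. This one is independent of Scalaz (keep it out of scope of the Scalaz import to avoid a name clash), and note that it puts the value first: runS returns (A, S), whereas Scalaz's State wraps S => (S, A).
case class State[S, +A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B](s => {
      val (a, s1) = runS(s) // run this step
      (f(a), s1)            // transform the value, keep the new state
    })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => {
      val (a, s1) = runS(s) // run this step
      f(a).runS(s1)         // feed its value and new state into the next step
    })
}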
Now it becomes clearer what State actually represents: a monad (hence map and flatMap) wrapping a function that, given one state S, returns an accompanying value A together with the next state.
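A quick sketch of map and flatMap in action, using the hand-rolled State (repeated so the snippet runs on its own). tick is a hypothetical step that returns the current counter as its value and increments the counter as its new state.

```scala
// Hand-rolled State, value first: runS returns (A, S)
case class State[S, +A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); (f(a), s1) })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
}

// Hypothetical step: value = current count, next state = count + 1
val tick: State[Int, Int] = State(n => (n, n + 1))

// map changes the value but leaves the state transition alone
val tickLabel: State[Int, String] = tick.map(n => s"ticket #$n")

// flatMap runs the second step on the state left by the first
val twoTicks: State[Int, (Int, Int)] = tick.flatMap(a => tick.map(b => (a, b)))

println(tickLabel.runS(10)) // (ticket #10,11)
println(twoTicks.runS(10))  // ((10,11),12)
```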
With that in place, here are the basic building blocks for manipulating the state:
def getState[S]: State[S, S] =
  State(s => (s, s))
where the function says: given a state, return that state.
def setState[S](s: S): State[S, Unit] =
  State(_ => ((), s))
where the function says: I don't care which state you give me, the next one is s. Oh, and I don't care about the accompanying value.
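Running the two building blocks directly shows what each one does to the value and the state. This sketch repeats a minimal hand-rolled State (only runS is needed here) so it runs on its own.

```scala
// Minimal hand-rolled State, value first: runS returns (A, S)
case class State[S, +A](runS: S => (A, S))

def getState[S]: State[S, S] = State(s => (s, s))
def setState[S](s: S): State[S, Unit] = State(_ => ((), s))

println(getState[Int].runS(5)) // (5,5): the value and the state are both the input
println(setState(9).runS(5))   // ((),9): the input state 5 is ignored
```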
def pureState[S, A](a: A): State[S, A] =
  State(s => (a, s))
where the function says: leave the state exactly as it is, and use a as the accompanying value.
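pureState is the monad's "unit": it passes the state through untouched, so wrapping a value with it and then calling flatMap should behave just like applying the function directly (the left-identity law). A small check of that, with double as a hypothetical example step:

```scala
// Hand-rolled State, value first: runS returns (A, S)
case class State[S, +A](runS: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
}

def pureState[S, A](a: A): State[S, A] = State(s => (a, s))

println(pureState[Int, String]("hello").runS(3)) // (hello,3): state unchanged

// Hypothetical step: doubles the value and adds it to the state
val double: Int => State[Int, Int] = x => State(s => (x * 2, s + x))

println(pureState[Int, Int](21).flatMap(double).runS(1)) // (42,22)
println(double(21).runS(1))                              // (42,22)
```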
Now for a stateful computation. It's very simple: it chains the state monads together with flatMap and map to build a unit of work that adds 1 to whatever state it is given.
val add1: State[Int, Unit] = for {
  n <- getState[Int] // pin S to Int; left alone, Scala infers Nothing here
  b <- setState(n + 1)
} yield b
println(add1.runS(7)) // ((),8): the value is (), the new state is 8
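Desugared, add1 is just a flatMap followed by a map. And because it is an ordinary value, it can be sequenced with itself; add3 below is a hypothetical example of that. The hand-rolled State and helpers are repeated so the snippet runs on its own.

```scala
// Hand-rolled State, value first: runS returns (A, S)
case class State[S, +A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); (f(a), s1) })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
}
def getState[S]: State[S, S] = State(s => (s, s))
def setState[S](s: S): State[S, Unit] = State(_ => ((), s))

// add1's for comprehension, written out by hand
val add1: State[Int, Unit] =
  getState[Int].flatMap(n => setState(n + 1).map(b => b))

// Hypothetical: reuse add1 three times in a row
val add3: State[Int, Unit] = for {
  _ <- add1
  _ <- add1
  _ <- add1
} yield ()

println(add1.runS(7)) // ((),8)
println(add3.runS(7)) // ((),10)
```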
This is trivial, but you can do more complicated things, such as numbering the elements of a list:
def zipWithIndex[A](as: List[A]): List[(Int, A)] =
  as.foldLeft(
    pureState[Int, List[(Int, A)]](List())
  )((acc, a) => for {
    xs <- acc
    n  <- getState[Int] // the state is the next index to hand out
    _  <- setState(n + 1)
  } yield (n, a) :: xs).runS(0)._1.reverse
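A usage sketch, with the hand-rolled State and helpers repeated so it runs on its own. Note that the result pairs come out as (index, element), the opposite order to the standard library's List.zipWithIndex.

```scala
// Hand-rolled State, value first: runS returns (A, S)
case class State[S, +A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); (f(a), s1) })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
}
def getState[S]: State[S, S] = State(s => (s, s))
def setState[S](s: S): State[S, Unit] = State(_ => ((), s))
def pureState[S, A](a: A): State[S, A] = State(s => (a, s))

def zipWithIndex[A](as: List[A]): List[(Int, A)] =
  as.foldLeft(pureState[Int, List[(Int, A)]](List()))((acc, a) =>
    for {
      xs <- acc
      n  <- getState[Int]
      _  <- setState(n + 1)
    } yield (n, a) :: xs
  ).runS(0)._1.reverse

println(zipWithIndex(List('a', 'b', 'c'))) // List((0,a), (1,b), (2,c))
println(List('a', 'b', 'c').zipWithIndex)  // List((a,0), (b,1), (c,2))
```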
Yoiks! What's this doing? Well, reading it desugared, and pretending we've been given a List of Chars, may make it easier:
type IndexedChars = List[(Int, Char)]
val as: List[Char] = List('a', 'b', 'c') // example input
val accumulated: State[Int, IndexedChars] = as.foldLeft(
  pureState[Int, IndexedChars](Nil) // "the next state is the last state and the accompanying value is an empty list"
)((acc: State[Int, IndexedChars], a: Char) => {
  acc.flatMap { xs: IndexedChars => // xs is the accompanying value that the accumulator wraps
    getState[Int].flatMap { n: Int => // "Given a state, return that state." (n is fed in via runS)
      setState(n + 1).map(ignored => // "I don't care which state you give me, the next one is n+1"
        ((n, a) +: xs) // the accompanying value is now the old one plus the (n, a) tuple
      )
    }
  }
})
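That is a bit more understandable. Note that accumulated is still only a description of the work: as before, nothing happens until you call runS with a starting index.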
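Running the desugared version shows both halves of the result: the accumulated list (built back to front, hence the reverse in zipWithIndex) and the final counter. This sketch repeats the hand-rolled pieces so it runs on its own; the input list and the start index of 5 are arbitrary examples.

```scala
// Hand-rolled State, value first: runS returns (A, S)
case class State[S, +A](runS: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); (f(a), s1) })
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State[S, B](s => { val (a, s1) = runS(s); f(a).runS(s1) })
}
def getState[S]: State[S, S] = State(s => (s, s))
def setState[S](s: S): State[S, Unit] = State(_ => ((), s))
def pureState[S, A](a: A): State[S, A] = State(s => (a, s))

type IndexedChars = List[(Int, Char)]
val as: List[Char] = List('x', 'y', 'z') // example input

val accumulated: State[Int, IndexedChars] =
  as.foldLeft(pureState[Int, IndexedChars](Nil)) { (acc, a) =>
    acc.flatMap(xs => getState[Int].flatMap(n => setState(n + 1).map(_ => (n, a) +: xs)))
  }

val (indexed, finalCount) = accumulated.runS(0)
println(indexed)         // List((2,z), (1,y), (0,x)): built back to front
println(finalCount)      // 3: the counter ends at the list's length
println(indexed.reverse) // List((0,x), (1,y), (2,z))

// The same description, run from a different starting state
val (fromFive, _) = accumulated.runS(5)
println(fromFive.reverse) // List((5,x), (6,y), (7,z))
```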