Writer Monads and Space Leaks

The 'WriterT space leak' is something that gets talked about a lot in Haskell folklore. Here's a short explanation of what that means.

Defining WriterT

The Writer monad represents computations with "extra output": produced data, logging messages, and so forth. That "extra output" has some constraints: we need to know what it means to have an "empty" piece of that data (so we can construct values with no extra output) and we need to know how to combine or append the "extra output" (so we can run two distinct computations and combine them appropriately.) We can get those properties in Haskell by ensuring that the "extra output" parameter has a Monoid constraint.

The WriterT transformer is defined like this:

newtype WriterT w m a =
  WriterT { runWriterT :: m (a, w) }

which is to say: it wraps an underlying monad m, which represents a computation that produces a pair of a (the result of our computation) and w (the "extra output" parameter). The Monad instance (I'm eliding the Functor and Applicative instances) looks like this:

instance (Monoid w, Monad m) => Monad (WriterT w m) where
  -- return the provided value, filling in the "extra output"
  -- with an "mempty" value.
  return x = WriterT $ return (x, mempty)

  -- to chain two WriterT computations together...
  m >>= f = WriterT $ do

    -- A: run the first computation, fetching its result
    -- and extra output
    (x, w)  <- runWriterT m

    -- B: produce the second computation by applying f to the
    -- result from the first computation, and then run it,
    -- fetching its result and extra output
    (y, w') <- runWriterT (f x)

    -- C: return the final result with the pieces of
    -- extra output from each computation combined together.
    return (y, w `mappend` w')

The last thing we need is the tell operation, which is the function we use to create our "extra output":

tell :: (Monoid w, Monad m) => w -> WriterT w m ()
tell o = WriterT $ return ((), o)

A simple, contrived WriterT computation looks like this:

sample :: WriterT (Sum Int) Identity ()
sample = do
  tell (Sum 2)
  tell (Sum 3)
  tell (Sum 4)
  return ()

Running this produces a return value of () as well as "extra output" equivalent to Sum 2 <> Sum 3 <> Sum 4 === Sum (2 + 3 + 4) === Sum 9. Sure enough:

>>> runIdentity $ runWriterT sample
((),Sum {getSum = 9})
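
For anyone who wants to load the definitions above into GHCi, here is a minimal self-contained sketch. The Functor and Applicative instances are my own filling-in of the ones elided above (modern GHC requires them before it will accept the Monad instance), and sampleResult is a name introduced only for this example.

```haskell
import Data.Functor.Identity (Identity, runIdentity)
import Data.Monoid (Sum (..))

newtype WriterT w m a =
  WriterT { runWriterT :: m (a, w) }

-- Assumed instance (elided in the text): map over the result,
-- leaving the extra output untouched.
instance Functor m => Functor (WriterT w m) where
  fmap f (WriterT m) = WriterT (fmap (\(a, w) -> (f a, w)) m)

-- Assumed instance (elided in the text): same shape as >>=,
-- combining the extra output of both sides.
instance (Monoid w, Monad m) => Applicative (WriterT w m) where
  pure x = WriterT (pure (x, mempty))
  WriterT mf <*> WriterT mx = WriterT $ do
    (f, w)  <- mf
    (x, w') <- mx
    pure (f x, w <> w')

instance (Monoid w, Monad m) => Monad (WriterT w m) where
  m >>= f = WriterT $ do
    (x, w)  <- runWriterT m
    (y, w') <- runWriterT (f x)
    pure (y, w <> w')

tell :: (Monoid w, Monad m) => w -> WriterT w m ()
tell o = WriterT (pure ((), o))

sample :: WriterT (Sum Int) Identity ()
sample = do
  tell (Sum 2)
  tell (Sum 3)
  tell (Sum 4)
  return ()

-- Hypothetical name: the pair of result and extra output, i.e. ((), Sum 9).
sampleResult :: ((), Sum Int)
sampleResult = runIdentity (runWriterT sample)
```

Note that these local definitions shadow nothing from the Prelude, but they would clash with the WriterT and tell exported by the transformers package if both were imported unqualified.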

In Space Leaks, No-One Can Hear You Scream

Well, what's the problem? Let's look more closely at the definition of >>=. In particular, notice that in order to start evaluating the call to mappend in step C, we need to have the output from both steps A and B. …except step B represents evaluating the entire rest of the computation, including an arbitrarily large number of other calls to >>=.

If we started evaluating the first call to >>= in runWriterT sample, then step A represents evaluating tell (Sum 2), and step B represents evaluating tell (Sum 3) >>= (\_ -> tell (Sum 4) >>= (\_ -> return ())). We get our output value w from the first one, and then we have to hold on to it while running the entire second computation in order to get the second extra output value w' before we can mappend them in step C. This problem appears with every call to >>=: we have to finish the entire subsequent chain of binds before we get the extra output to combine.
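
One way to see that nesting directly is to swap Sum Int for a type that records how <> was parenthesised instead of actually combining anything. The sketch below assumes the lazy Writer from the transformers package, whose >>= has the same shape as the one defined above. Trace, traced, and nesting are names made up for this illustration, and Trace deliberately breaks the Monoid laws, so it is only good for tracing.

```haskell
import Control.Monad.Trans.Writer.Lazy (Writer, execWriter, tell)

-- Hypothetical tracing type: NOT a lawful Monoid. It exists only to
-- show where the parentheses end up.
newtype Trace = Trace String
  deriving (Eq, Show)

instance Semigroup Trace where
  Trace a <> Trace b = Trace ("(" ++ a ++ " <> " ++ b ++ ")")

instance Monoid Trace where
  mempty = Trace "mempty"

-- The same three-step computation as sample, with traced output.
traced :: Writer Trace ()
traced = do
  tell (Trace "2")
  tell (Trace "3")
  tell (Trace "4")
  return ()

-- Evaluates to "(2 <> (3 <> (4 <> mempty)))"
nesting :: String
nesting = let Trace s = execWriter traced in s
```

The outermost <> has the whole rest of the computation as its right-hand argument, which is exactly the dependency described above: the first value has to wait for everything after it.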

This is especially silly for sample, where I've chosen to use Sum Int as the extra output. It's just an accumulator! But in order to actually start adding those numbers together, we have to go through the entire chain of binds and produce each number first, each of which has to be kept around until we get to the final return, and only then can we go back and start adding the numbers we've generated together. In this case, it's not gonna kill our heap, but consider a computation like

largerSample :: WriterT (Sum Int) Identity ()
largerSample = sequence_ [ tell (Sum 1) | _ <- [1 .. 1000000 :: Int] ]

This computation will contain 1000000 calls to >>=, which means we will generate 1000000 instances of the number 1 first, and only then can we start adding them together. This can't be changed with strictness annotations or modifications to evaluation strategy: we'll need to rearrange the way that data depends on other data.

Alleviating the Space Leak

The solution is described in short by Gabriel Gonzalez here and here on the mailing lists, but I'll explain the same thing slightly more verbosely:

What we really want is for the call to mappend to happen earlier, which means moving around the way our intermediate values depend on each other.[1] We've written our WriterT so that the "extra output" is always returned from a computation, and never passed to a computation: that means a given computation doesn't "know" about the output accumulated so far, and we can only run our mappend operation after we've finished it. Let's try writing an alternate WriterishT monad that differs by also taking the "extra output" from previous steps:

newtype WriterishT w m a =
  WriterishT { runWriterishT :: w -> m (a, w) }

The w on the left side of the arrow represents the intermediate value of the "extra output" that comes from previous steps. We can move the mappend step from the implementation of >>= into the tell function, which now combines the new value being "told" with the previous values:

tellish :: (Monoid w, Monad m) => w -> WriterishT w m ()
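-- the previously accumulated output goes on the left, so that
-- output stays in the order it was "told"
tellish w = WriterishT $ \w' -> return ((), w' <> w)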

The implementation of return used to return mempty, but now it's taking an intermediate value it might need to combine with the 'new state', so it would look like this:

return x = WriterishT $ \w -> return (x, w <> mempty)

Luckily, the Monoid laws tell us that w <> mempty === w, so we can simplify that definition down a fair amount. This ends up removing all mention of the Monoidal nature of our data in the Monad instance, so we can omit the Monoid w constraint:

instance (Monad m) => Monad (WriterishT w m) where
  -- the motivation for this implementation is explained
  -- up above...
  return x = WriterishT $ \w -> return (x, w)

  -- The bind instance looks pretty similar, but again, we
  -- need to take the current state of our accumulator:
  m >>= f = WriterishT $ \w -> do
    -- A: we pass that on to the first computation...
    (x, w')  <- runWriterishT m w
    -- B: and pass that on to the second computation...
    (y, w'') <- runWriterishT (f x) w'
    -- C: and return just the second state, because the
    -- call to 'mappend' was taken care of inside one
    -- of these.
    return (y, w'')

Now each call to mappend has all the data it needs as soon as tellish happens. Step C only relies on the "extra output" from B, which itself only relies on the "extra output" from A. That means the mappend for step A can already be run before step B starts, and the older intermediate value can be thrown away.
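
Here is a runnable sketch of the pieces so far, in case it helps to experiment. As before, the Functor and Applicative instances are my own filling-in, and execWriterish, sampleish, and orderedLog are hypothetical names introduced just for this example. This version of tellish puts the accumulated output on the left of <>, so that non-commutative output such as a list stays in the order it was told.

```haskell
import Data.Functor.Identity (Identity, runIdentity)
import Data.Monoid (Sum (..))

newtype WriterishT w m a =
  WriterishT { runWriterishT :: w -> m (a, w) }

-- Assumed instance: map over the result, pass the accumulator through.
instance Functor m => Functor (WriterishT w m) where
  fmap f (WriterishT g) =
    WriterishT $ \w -> fmap (\(a, w') -> (f a, w')) (g w)

-- Assumed instance: thread the accumulator left to right.
instance Monad m => Applicative (WriterishT w m) where
  pure x = WriterishT $ \w -> pure (x, w)
  WriterishT mf <*> WriterishT mx = WriterishT $ \w -> do
    (f, w')  <- mf w
    (x, w'') <- mx w'
    pure (f x, w'')

instance Monad m => Monad (WriterishT w m) where
  m >>= f = WriterishT $ \w -> do
    (x, w') <- runWriterishT m w
    runWriterishT (f x) w'

-- Accumulated output on the left, newly told value on the right.
tellish :: (Monoid w, Monad m) => w -> WriterishT w m ()
tellish w = WriterishT $ \acc -> pure ((), acc <> w)

-- Hypothetical helper: run with an empty accumulator, keep the output.
execWriterish :: Monoid w => WriterishT w Identity a -> w
execWriterish m = snd (runIdentity (runWriterishT m mempty))

sampleish :: WriterishT (Sum Int) Identity ()
sampleish = do
  tellish (Sum 2)
  tellish (Sum 3)
  tellish (Sum 4)

-- execWriterish sampleish == Sum 9
-- orderedLog == ["first", "second"]
orderedLog :: [String]
orderedLog = execWriterish (tellish ["first"] >> tellish ["second"])
```

The caller now has to supply the starting value (mempty here), which is the one place the Monoid constraint still shows up outside of tellish.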

…this is a little bit dishonest, though: the type that we've made is significantly more powerful than just a WriterT. One feature of WriterT is that there's no way for a computation to peek at or rely on the "extra output" from previous steps. However, if we aren't careful about the interface to our WriterishT monad—in particular, if we don't hide the constructor—a user could write code like this

getCurrentOutput :: Monad m => WriterishT w m w
getCurrentOutput = WriterishT $ \w -> return (w, w)

poorlyBehaved :: WriterishT (Sum Int) Identity String
poorlyBehaved = do
  Sum n <- getCurrentOutput
  if n == 0
    then return "zero"
    else return "nonzero"

which would be impossible in WriterT! (If you can't see why, it's worth trying to write it and seeing what walls you come up against.) This is because, in point of fact, what I've been calling WriterishT is literally just StateT by a different name! That makes it much more powerful than WriterT, but accordingly it removes some of the guarantees you get by using a WriterT, as well.
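
To make that equivalence concrete, here is a small sketch of the same idea written directly against the strict State monad from the transformers package. The names tellState, sampleState, peeking, and the two result bindings are made up for this example. I've used modify', which forces the new accumulator before continuing; that choice is mine and not something the discussion above depends on.

```haskell
import Control.Monad.Trans.State.Strict
  (State, evalState, execState, get, modify')
import Data.Monoid (Sum (..))

-- Hypothetical counterpart of tellish: append to the state,
-- forcing the combined value as we go.
tellState :: Monoid w => w -> State w ()
tellState w = modify' (<> w)

sampleState :: State (Sum Int) ()
sampleState = do
  tellState (Sum 2)
  tellState (Sum 3)
  tellState (Sum 4)

-- The same kind of peeking as poorlyBehaved: plain 'get' is all it takes.
peeking :: State (Sum Int) String
peeking = do
  tellState (Sum 5)
  Sum n <- get
  return (if n == 0 then "zero" else "nonzero")

-- Sum 9
sampleStateResult :: Sum Int
sampleStateResult = execState sampleState mempty

-- "nonzero"
peekingResult :: String
peekingResult = evalState peeking mempty
```

If you want the WriterT guarantees back, one option is to wrap something like this in a newtype and export only the tell-like operation, so that get stays out of reach of users.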

So in summary, the WriterishT or StateT can alleviate the space-leak problems with WriterT, but at the cost of some of the desirable properties of WriterT. In most cases, the space leaks from WriterT won't turn out to be a huge problem, but it's good to know what's happening and why it comes about!

  1. This is a good way of thinking about a lot of performance problems in a lazy language: not in terms of 'steps', like in an imperative language, but in terms of dependencies among your subexpressions.