HPR3028: Monads and Haskell

Published: March 11, 2020, midnight

This is basically a transcript of the post I wrote on the subject, which I host here. It has a bit more than what I talked about.

Join in Haskell

join is a monadic operation. Instead of working only on lists, it works on any monad, and it has the signature:

   join :: Monad m => m ( m a ) -> (m a)

In effect it joins or merges two successive monad applications into a single monad application. But join is not part of the canonical monad definition, which is given by:

   return :: Monad m => a -> m a
   (>>=) :: Monad m => m a -> ( a -> m b ) -> m b

A simple way to think of the relationship between join and (>>=) is that, since every monad is a functor, (>>=) maps its second argument over its first argument and then uses join to merge the two applications of the monad constructor, i.e.:

  (x >>= f) = join $ fmap f x
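To make this identity concrete, here is a minimal sketch that rebuilds bind from fmap and the library's join. The names bindViaJoin and safeDiv100 are hypothetical, made up for this illustration only.

```haskell
import Control.Monad (join)

-- Hypothetical name: (>>=) rebuilt from fmap and join, as in the identity above.
bindViaJoin :: Monad m => m a -> (a -> m b) -> m b
bindViaJoin x f = join (fmap f x)

-- Illustrative helper (not from the original post): divide 100 by n,
-- failing with Nothing when n is zero.
safeDiv100 :: Int -> Maybe Int
safeDiv100 0 = Nothing
safeDiv100 n = Just (100 `div` n)
```

Loaded in GHCi, bindViaJoin (Just 5) safeDiv100 and Just 5 >>= safeDiv100 should give the same result, and likewise for lists.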

However, since join is not primitive, it has to be constructed from return and (>>=). The naive idea is to trick (>>=) into applying a function that does not pile yet another m onto our initial type m (m a). Surprisingly, this actually works if we let

   join x = (x >>= id)

Initially this is surprising, since id has the signature c -> c instead of the necessary a -> m b! However, when c is not an atomic type but rather of the form m d for some (maybe atomic) type d, id has the signature m d -> m d. If we bind type a to m d and type b to d, we obtain id with the type signature a -> m b, so it can indeed be used as the second argument of (>>=), and everything makes sense.
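The type argument can be checked by the compiler. The sketch below pins id to the specialised signature for the Maybe monad; the names idAsKleisli and joinMaybe are hypothetical and only serve this illustration.

```haskell
-- id specialised to the shape (>>=) expects: a -> m b with a = Maybe d, b = d.
-- The name idAsKleisli is hypothetical.
idAsKleisli :: Maybe d -> Maybe d
idAsKleisli = id

-- join for Maybe, using the specialised id as the second argument of (>>=).
joinMaybe :: Maybe (Maybe d) -> Maybe d
joinMaybe x = x >>= idAsKleisli
```

If the signature of idAsKleisli did not fit a -> m b, the definition of joinMaybe would be rejected at compile time.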

Now, style is important, so we can eta-reduce this to get a point-free implementation by partially applying (>>=) to its second argument:

   join = (>>= id)
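As a quick sanity check, here is the point-free definition under a hypothetical name, joinViaBind, so that it does not clash with the join exported by Control.Monad. The two example values are my own and not from the original post.

```haskell
-- Hypothetical name, chosen to avoid clashing with Control.Monad.join.
joinViaBind :: Monad m => m (m a) -> m a
joinViaBind = (>>= id)

-- Two successful layers collapse into one.
exampleMaybe :: Maybe Int
exampleMaybe = joinViaBind (Just (Just 2))

-- An inner failure survives the collapse of the outer Right.
exampleEither :: Either String Int
exampleEither = joinViaBind (Right (Left "inner failure"))
```

The explicit type signature matters here: without it, the monomorphism restriction would stop GHC from generalising the point-free definition over all monads.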

This is all fine and well for the type checker, and it does work, but it's also important to understand how it works, so let's see it in a simple example, using the Maybe monad. Let's start by refreshing the implementation of the monad instance:

instance Monad Maybe where

   (>>=) :: Maybe a -> (a -> Maybe b) -> Maybe b
   (>>=) Nothing _ = Nothing
   (>>=) (Just x) f = f x

   return :: a -> Maybe a
   return = Just

So let's now go through the successive bindings when applying (>>= id) to Just (Just 2):

   Just x = Just (Just 2) => x = Just 2
   Just x >>= id = id x = x
   x = Just 2
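The Maybe instance cannot be redefined in a real module, because the Prelude already provides it. A minimal runnable sketch of the same trace therefore uses a stand-in type, here called Option (a hypothetical name), together with the Functor and Applicative instances that modern GHC requires before a Monad instance.

```haskell
-- Stand-in for Maybe; Option, None and Some are hypothetical names.
data Option a = None | Some a
  deriving (Eq, Show)

instance Functor Option where
  fmap _ None     = None
  fmap f (Some x) = Some (f x)

instance Applicative Option where
  pure = Some
  None   <*> _ = None
  Some f <*> x = fmap f x

-- Same two equations as the Maybe instance above.
instance Monad Option where
  None   >>= _ = None
  Some x >>= f = f x

-- Some x >>= id reduces to id x, which is x.
joinOption :: Option (Option a) -> Option a
joinOption = (>>= id)
```

Evaluating joinOption (Some (Some 2)) follows the same three steps as the trace: the pattern Some x binds x to Some 2, the bind reduces to id x, and the result is Some 2.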

This example carries over almost verbatim to the Either monad and many other monads that follow a similar principle, so let's look at a slightly more complex example: the List monad.

instance Monad [] where

   (>>=) :: [a] -> (a -> [b]) -> [b]
   (>>=) [] _ = []
   (>>=) (x:xs) f = f x ++ (xs >>= f)

   return :: a -> [a]
   return = (:[])

Following the bindings again for [[2,3],[4]] >>= id, we have:

   [[2,3],[4]] = (x:xs) => x = [2,3] ; xs = [[4]]
   (x:xs) >>= id = (id x) ++ (xs >>= id) = x ++ (xs >>= id)
   [[4]] = (y:ys) => y = [4] ; ys = []
   (y:ys) >>= id = y ++ (ys >>= id)
   ys = [] => (ys >>= id) = []
   => (y:ys) >>= id  =  [4] ++ [] = [4]
   => xs >>= id = [4]
   => (x:xs) >>= id = x ++ [4] = [2,3] ++ [4] = [2,3,4]

So all of this is to say that join actually does what one expects it to do on a list of lists. It joins them one by one into a single list.
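The list trace can be replayed with a hand-written bind. This is a sketch under the assumption that the recursive definition shown in the instance is used as is; bindList and joinList are hypothetical names that avoid clashing with the Prelude.

```haskell
-- Hypothetical helper mirroring the recursive (>>=) for lists shown above.
bindList :: [a] -> (a -> [b]) -> [b]
bindList []     _ = []
bindList (x:xs) f = f x ++ bindList xs f

-- join for lists: bind with id, so each inner list is appended in turn.
joinList :: [[a]] -> [a]
joinList xss = bindList xss id
```

On any list of lists, joinList should agree with the Prelude's concat.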

The Associativity Law

The associativity law of a monad can be quite confusing. After all, it takes the form:

   (m >>= f) >>= g == m >>= ( \x -> f x >>= g )

While not too complicated to understand, it is difficult to see how it relates to a usual associativity law, which follows the form a * (b * c) = (a * b) * c. To recover the associativity, the usual explanation is that one has to see it in terms of the monadic function composition (>=>). While this is a valid way of doing so, I like to decompose things in terms of fmap, (>>=) and join.
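For comparison, this is roughly what the usual (>=>) explanation looks like in code. The three Maybe-returning functions are hypothetical helpers invented for this sketch; only (>=>) itself comes from Control.Monad.

```haskell
import Control.Monad ((>=>))

-- Hypothetical helpers, each of which may fail with Nothing.
halveEven :: Int -> Maybe Int
halveEven n
  | even n    = Just (n `div` 2)
  | otherwise = Nothing

predPositive :: Int -> Maybe Int
predPositive n
  | n > 0     = Just (n - 1)
  | otherwise = Nothing

tripleSmall :: Int -> Maybe Int
tripleSmall n
  | n < 100   = Just (n * 3)
  | otherwise = Nothing

-- The law in its familiar shape: (a * b) * c on one side ...
groupedLeft :: Int -> Maybe Int
groupedLeft = (halveEven >=> predPositive) >=> tripleSmall

-- ... and a * (b * c) on the other.
groupedRight :: Int -> Maybe Int
groupedRight = halveEven >=> (predPositive >=> tripleSmall)
```

Both groupings should return the same result for every input, which is exactly the a * (b * c) = (a * b) * c shape with (>=>) playing the role of the operator.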

So let's use what we did previously on the associativity law, starting on the left side, and replacing the atomic-looking value m with the constructor application M z (a value z wrapped in the monad constructor M):

   (M z >>= f) >>= g
   = join (fmap f (M z)) >>= g
   = join (fmap g $ join ( fmap f (M z)))
   = join (fmap g $ join ( M (f z) ))

and the right-hand side:

   M z >>= ( \x -> f x >>= g )
   = M z >>= ( \x -> join $ fmap g (f x))
   = join $ fmap ( \x -> join $ fmap g (f x)) (M z)
   = join $ M (( \x -> join $ fmap g (f x)) z)
   = join $ M (join $ fmap g (f z))
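Both sides of the derivation can be written as ordinary functions and compared on a concrete monad. This is only a sketch: sequentialSide, nestedSide, neighbours and scaled are hypothetical names, and the list monad is just one convenient test case.

```haskell
import Control.Monad (join)

-- Left-hand side: join . fmap f, followed by join . fmap g.
sequentialSide :: Monad m => (a -> m b) -> (b -> m c) -> m a -> m c
sequentialSide f g m = join (fmap g (join (fmap f m)))

-- Right-hand side: apply f and g inside the constructor, join once inside
-- and once outside.
nestedSide :: Monad m => (a -> m b) -> (b -> m c) -> m a -> m c
nestedSide f g m = join (fmap (\x -> join (fmap g (f x))) m)

-- Hypothetical list-returning functions to play the roles of f and g.
neighbours :: Int -> [Int]
neighbours n = [n, n + 1]

scaled :: Int -> [Int]
scaled n = [n * 2, n * 3]
```

For the input [1,2], both sequentialSide neighbours scaled and nestedSide neighbours scaled should produce [2,3,4,6,4,6,6,9].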

And essentially what it says is that, given two functions f and g, I can either apply them sequentially, with join . fmap f followed by join . fmap g, or I can apply them both within the constructor and then join once within the constructor and once outside the constructor, and I should get the same thing. In fact, if you replace f and g with id, this is the associativity law for monads as usually presented in a standard category theory textbook such as Mac Lane's Categories for the Working Mathematician (and recovering this was the whole point of this exercise in the first place).
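With f and g set to id, the two sides reduce to join . join and join . fmap join acting on a triply nested value. A minimal sketch of that special case, with hypothetical function names, is:

```haskell
import Control.Monad (join)

-- Merge the two outer layers first, then the remaining two.
flattenOuterFirst :: Monad m => m (m (m a)) -> m a
flattenOuterFirst = join . join

-- Merge the two inner layers first (inside the constructor), then the rest.
flattenInnerFirst :: Monad m => m (m (m a)) -> m a
flattenInnerFirst = join . fmap join
```

On a list such as [[[1,2],[3]],[[4]]], both functions should give [1,2,3,4]: it does not matter which pair of layers is merged first.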