Ganesh Sittampalam (hsenag) wrote,

Restricted monads in Haskell

I was playing around with restricted monads and came up with the following. It seems really simple, so I wonder whether it's already known, or just obvious.

The restricted monad problem is well-known in Haskell. We have some type constructor Foo and some restriction Restr, such that Foo is a monad, but only for contained types that are members of Restr. We can't make Foo an instance of Monad, because in the normal Monad class the types of return and (>>=) are fully polymorphic in the contained type. This in turn blocks us from using do-notation with our "monad". We can get round this using NoImplicitPrelude in GHC, but that's rather messy and means that normal Monads don't work properly in that module.
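
The workaround can be sketched as follows. This is a minimal sketch that assumes GHC 7.0 or later. In those versions, the rebinding of do-notation that NoImplicitPrelude used to switch on moved into the separate RebindableSyntax extension, which in turn implies NoImplicitPrelude. The operator definitions and the pairs value are illustrative and do not come from this post.

```haskell
{-# LANGUAGE RebindableSyntax #-}

-- RebindableSyntax implies NoImplicitPrelude, so import the Prelude by hand,
-- leaving out the names that do-notation should take from this module.
import Prelude hiding (return, (>>), (>>=))
import Data.Set (Set)
import qualified Data.Set as Set

-- Set versions of the monad operations, carrying the Ord constraint that
-- the Monad class has no room for.
(>>=) :: Ord b => Set a -> (a -> Set b) -> Set b
s >>= f = Set.unions (map f (Set.toList s))

(>>) :: Ord b => Set a -> Set b -> Set b
s >> t = s >>= \_ -> t

return :: a -> Set a
return = Set.singleton

-- do-notation in this module now desugars to these Set operations, so an
-- ordinary IO or Maybe do block written here would use them too.
pairs :: Set (Int, Int)
pairs = do
  x <- Set.fromList [1, 2]
  y <- Set.fromList [x, 3]
  return (x, y)
```

Here pairs is fromList [(1,1),(1,3),(2,2),(2,3)]. The catch is the one described above: every do block in the module now means Set.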

For concreteness, suppose that Restr is actually Ord, but we could use anything. We'll parameterise over the actual "monad" type, so we don't need to decide on that yet, but I have the usual Set example in mind.

First, let's define a restricted monad class:
class OrdMonad m where
  ordReturn :: Ord a => a -> m a
  ordBind :: (Ord a, Ord b) => m a -> (a -> m b) -> m b

To keep things concrete: Set is obviously an instance of this class (with Data.Set imported qualified as Set):
instance OrdMonad Set where
  ordReturn = Set.singleton
  -- union together the sets that f produces for each element
  s `ordBind` f = Set.foldr (\v ret -> f v `Set.union` ret) Set.empty s
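
Here is a quick check of what ordBind does for Set. It behaves like the list monad's concatMap, followed by deduplication. The small self-contained sketch below assumes the conventional Data.Set imports, which the post doesn't show. It uses Set.foldr, because containers keeps Set.fold only for compatibility. The value halves is illustrative and does not come from the post.

```haskell
import Data.Set (Set)
import qualified Data.Set as Set

-- The restricted monad class: return and bind only for element types in Ord.
class OrdMonad m where
  ordReturn :: Ord a => a -> m a
  ordBind :: (Ord a, Ord b) => m a -> (a -> m b) -> m b

instance OrdMonad Set where
  ordReturn = Set.singleton
  -- Union together the sets that f produces for each element.
  s `ordBind` f = Set.foldr (\v ret -> f v `Set.union` ret) Set.empty s

-- Illustrative value: each x yields {x `div` 2, x}; overlapping results merge.
halves :: Set Int
halves = Set.fromList [2, 4, 6] `ordBind` \x -> Set.fromList [x `div` 2, x]
```

The value halves evaluates to fromList [1,2,3,4,6]. The 2 is produced by both 2 and 4, but appears only once.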

Now, how can I make a monad from this? Let's start by defining a new type constructor, GADT-style. We intend to apply this type constructor to our OrdMonad instance.
data AsMonad m a where

Now we need some data constructors. Firstly we want to be able to embed "proper" OrdMonads. Here we'll need the full power of the GADTs extension, i.e. restricted return types:
  Embed :: (OrdMonad m, Ord a) => m a -> AsMonad m a

OK so far, but what we're really after is a way to implement return and (>>=). Well, let's take the easy way out:
  Return :: OrdMonad m => a -> AsMonad m a
  Bind :: OrdMonad m => AsMonad m a -> (a -> AsMonad m b) -> AsMonad m b

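Now we can implement Monad trivially (I'll ignore fail, but it's not hard to add). On GHC 7.10 and later, Monad also requires Functor and Applicative instances. liftM and ap from Control.Monad supply them, and return defaults to pure:
instance OrdMonad m => Functor (AsMonad m) where fmap = liftM
instance OrdMonad m => Applicative (AsMonad m) where
  pure = Return
  (<*>) = ap
instance OrdMonad m => Monad (AsMonad m) where
  (>>=) = Bind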

That was a nice bit of sleight-of-hand, but did it actually help? We've just delayed the problem till later.

Well, actually it does help. "Later", what we'll want to do is get back to our m a type from AsMonad m a. But at this point we can restrict a to being in Ord. What we want is a function unEmbed:
unEmbed :: Ord a => AsMonad m a -> m a

The Embed case of unEmbed is easy:
unEmbed (Embed m) = m

Since we've restricted a, the Return case is easy too:
unEmbed (Return v) = ordReturn v

Now for Bind. Let's split that up into cases based on what the left-hand argument is. Yes, I know this seems like delaying the inevitable, that's how it felt to me too!

If the left-hand argument is Embed, then both a and b are in Ord: a because Embed carries that constraint, and b because of unEmbed's own signature. So we can call unEmbed recursively and use ordBind:
unEmbed (Bind (Embed m) f) = m `ordBind` (unEmbed . f)

For Return, the left identity law (return v >>= f = f v) applies:
unEmbed (Bind (Return v) f) = unEmbed (f v)

Now for the Bind case. My initial assumption when I was writing this code was that I'd be trapped in a loop, only able to break out the left argument of the inner Bind into yet more cases. Then I realised that we can simply bring the monad laws to bear again. This time it's associativity, (m >>= f) >>= g = m >>= (\x -> f x >>= g), which moves the nesting to the right:
unEmbed (Bind (Bind m f) g) = unEmbed (Bind m (\x -> Bind (f x) g))

And, well, that's it. We can use do-notation on the AsMonad type, and move freely between that and the base type using Embed and unEmbed.
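
Here is a minimal sketch of the whole construction so far, put together as one module. It assumes GHC 7.10 or later, so it adds Functor and Applicative instances via liftM and ap, and it uses Set.foldr for the Set instance. The values sums and leftNested are illustrative and do not come from the post. sums threads a function, which has no Ord instance, through a do block. leftNested builds a left-nested bind to exercise the reassociation case.

```haskell
{-# LANGUAGE GADTs #-}

import Control.Monad (ap, liftM)
import Data.Set (Set)
import qualified Data.Set as Set

class OrdMonad m where
  ordReturn :: Ord a => a -> m a
  ordBind :: (Ord a, Ord b) => m a -> (a -> m b) -> m b

instance OrdMonad Set where
  ordReturn = Set.singleton
  s `ordBind` f = Set.foldr (\v ret -> f v `Set.union` ret) Set.empty s

-- Return and Bind just record the computation; only Embed carries Ord
-- evidence for its element type.
data AsMonad m a where
  Embed :: (OrdMonad m, Ord a) => m a -> AsMonad m a
  Return :: OrdMonad m => a -> AsMonad m a
  Bind :: OrdMonad m => AsMonad m a -> (a -> AsMonad m b) -> AsMonad m b

-- Functor and Applicative are required superclasses of Monad since GHC 7.10.
instance OrdMonad m => Functor (AsMonad m) where
  fmap = liftM

instance OrdMonad m => Applicative (AsMonad m) where
  pure = Return
  (<*>) = ap

instance OrdMonad m => Monad (AsMonad m) where
  (>>=) = Bind

-- Run the recorded computation in the restricted monad; only the final
-- element type has to be in Ord.
unEmbed :: Ord a => AsMonad m a -> m a
unEmbed (Embed m) = m
unEmbed (Return v) = ordReturn v
unEmbed (Bind (Embed m) f) = m `ordBind` (unEmbed . f)
-- left identity: return v >>= f = f v
unEmbed (Bind (Return v) f) = unEmbed (f v)
-- associativity: (m >>= f) >>= g = m >>= (\x -> f x >>= g)
unEmbed (Bind (Bind m f) g) = unEmbed (Bind m (\x -> Bind (f x) g))

-- Illustrative: the intermediate g is a function, which has no Ord instance.
sums :: Set Int
sums = unEmbed $ do
  x <- Embed (Set.fromList [1, 2, 3])
  g <- return (+ x)
  y <- Embed (Set.fromList [10, 20])
  return (g y)

-- Illustrative: a left-nested bind, handled by the associativity case.
leftNested :: Set Int
leftNested = unEmbed ((Embed (Set.fromList [1, 2]) >>= \x -> return (x * 10)) >>= \y -> return (y + 1))
```

Loaded into GHCi, sums should evaluate to fromList [11,12,13,21,22,23] and leftNested to fromList [11,21].
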
MonadPlus is a simple addition along the same lines:
class OrdMonad m => OrdMonadPlus m where
  ordMZero :: Ord a => m a
  ordMPlus :: Ord a => m a -> m a -> m a

instance OrdMonadPlus Set where
  ordMZero = Set.empty
  ordMPlus = Set.union

data AsMonad m a where
  -- Embed, Return and Bind as before, plus:
  MZero :: OrdMonadPlus m => AsMonad m a
  MPlus :: OrdMonadPlus m => AsMonad m a -> AsMonad m a -> AsMonad m a

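instance OrdMonadPlus m => Alternative (AsMonad m) where  -- Control.Applicative; required since GHC 7.10
  empty = MZero
  (<|>) = MPlus

instance OrdMonadPlus m => MonadPlus (AsMonad m)  -- mzero and mplus default to empty and (<|>)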

-- additional unEmbed equations (same signature as before):
unEmbed MZero = ordMZero
unEmbed (MPlus m1 m2) = ordMPlus (unEmbed m1) (unEmbed m2)
unEmbed (Bind MZero f) = unEmbed MZero
unEmbed (Bind (MPlus m1 m2) f) = unEmbed (MPlus (Bind m1 f) (Bind m2 f))
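
Here is a self-contained sketch of the MonadPlus version, again assuming GHC 7.10 or later. In those versions MonadPlus has Alternative as a superclass and guard is defined via Alternative, so the instances below go through Alternative. All five constructors live in a single AsMonad declaration. evenSums is a hypothetical example and is not the test code from this post.

```haskell
{-# LANGUAGE GADTs #-}

import Control.Applicative (Alternative (..))
import Control.Monad (MonadPlus (..), ap, guard, liftM)
import Data.Set (Set)
import qualified Data.Set as Set

class OrdMonad m where
  ordReturn :: Ord a => a -> m a
  ordBind :: (Ord a, Ord b) => m a -> (a -> m b) -> m b

class OrdMonad m => OrdMonadPlus m where
  ordMZero :: Ord a => m a
  ordMPlus :: Ord a => m a -> m a -> m a

instance OrdMonad Set where
  ordReturn = Set.singleton
  s `ordBind` f = Set.foldr (\v ret -> f v `Set.union` ret) Set.empty s

instance OrdMonadPlus Set where
  ordMZero = Set.empty
  ordMPlus = Set.union

-- All five constructors in one declaration.
data AsMonad m a where
  Embed :: (OrdMonad m, Ord a) => m a -> AsMonad m a
  Return :: OrdMonad m => a -> AsMonad m a
  Bind :: OrdMonad m => AsMonad m a -> (a -> AsMonad m b) -> AsMonad m b
  MZero :: OrdMonadPlus m => AsMonad m a
  MPlus :: OrdMonadPlus m => AsMonad m a -> AsMonad m a -> AsMonad m a

instance OrdMonad m => Functor (AsMonad m) where
  fmap = liftM

instance OrdMonad m => Applicative (AsMonad m) where
  pure = Return
  (<*>) = ap

instance OrdMonad m => Monad (AsMonad m) where
  (>>=) = Bind

-- Since GHC 7.10, MonadPlus has Alternative as a superclass, and guard uses it.
instance OrdMonadPlus m => Alternative (AsMonad m) where
  empty = MZero
  (<|>) = MPlus

-- mzero and mplus default to empty and (<|>).
instance OrdMonadPlus m => MonadPlus (AsMonad m)

unEmbed :: Ord a => AsMonad m a -> m a
unEmbed (Embed m) = m
unEmbed (Return v) = ordReturn v
unEmbed MZero = ordMZero
unEmbed (MPlus m1 m2) = ordMPlus (unEmbed m1) (unEmbed m2)
unEmbed (Bind (Embed m) f) = m `ordBind` (unEmbed . f)
unEmbed (Bind (Return v) f) = unEmbed (f v)
unEmbed (Bind (Bind m f) g) = unEmbed (Bind m (\x -> Bind (f x) g))
-- mzero >>= f = mzero
unEmbed (Bind MZero _) = ordMZero
-- distribute the bind over both branches of an mplus
unEmbed (Bind (MPlus m1 m2) f) = unEmbed (MPlus (Bind m1 f) (Bind m2 f))

-- Illustrative: keep the even sums, then add a fallback branch with mplus.
evenSums :: Set Int
evenSums = unEmbed (pairSums `mplus` return 100)
  where
    pairSums :: AsMonad Set Int
    pairSums = do
      x <- Embed (Set.fromList [1, 2, 3])
      y <- Embed (Set.fromList [1, 2])
      guard (even (x + y))
      return (x + y)
```

evenSums should come out as fromList [2,4,100]. The guard discards the odd sums via MZero, and mplus unions in the fallback branch.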

Here's some test code:
newtype Wrap a = Wrap { unWrap :: a } -- not an Ord even if a is

test1 = unEmbed $ do x <- Embed $ Set.fromList [6, 2, 3]
                     do y <- return (Wrap x)
                        z <- Embed $ Set.fromList [1..2]
                        guard (unWrap y < 5)
                        return (unWrap y + z)
                        return 10

One annoyance is that we can't parametrise over typeclasses (at least not nicely), so we can't make AsMonad fully general; instead we need one for each restriction.

Finally, if we are willing and able to add extra constructors to an existing type, I think it should be possible to directly make that type into a Monad using the same approach.

The closest thing I've seen to this before is something like this: It's the same sort of approach, but I don't think it generalises to arbitrary restricted monads in the same way as this.
Tags: haskell