In a previous post we discussed how to use a GADT to conveniently build a computational graph for Bayesian models. Today we discuss another GADT-related topic: suppose we have already built a GADT; how can we make it an instance of the Monad typeclass?

The reason we ask this question is as follows:

With a conventional ADT, we can view a data type like data ADist a = ADist a as a simple wrapper around the type variable a. A Monad instance can then be defined easily:

-- Functor and Applicative instances (required since GHC 7.10) are omitted for brevity.
instance Monad ADist where
    return = ADist
    ADist a >>= f = f a

That is, the bind function (>>=) just needs to fetch the content of the ADist wrapper and apply the function f to it.
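For readers who want to compile the wrapper example on a current GHC, here is a minimal sketch that also supplies the Functor and Applicative instances. The Eq/Show deriving and the example value are illustrative additions, not part of the original definition.

```haskell
-- A plain wrapper: exactly one value of type a lives inside.
data ADist a = ADist a
  deriving (Show, Eq)

instance Functor ADist where
  fmap f (ADist a) = ADist (f a)

instance Applicative ADist where
  pure = ADist
  ADist f <*> ADist a = ADist (f a)

instance Monad ADist where
  -- Unwrap the single stored value and hand it to the continuation.
  ADist a >>= f = f a

-- Hypothetical usage: chain two steps through (>>=).
example :: ADist Int
example = ADist 20 >>= \x -> ADist (x + 1) >>= \y -> ADist (y * 2)
```

Here example evaluates to ADist 42, because each (>>=) immediately unwraps the value and runs the next step.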

However, in the case of a GADT, it is not so apparent what data it wraps (or, when many things are wrapped inside, which one we are really referring to when we use (>>=)). For example:

data Dist a where
    Return      :: a -> Dist a
    Bind        :: Dist b -> (b -> Dist a) -> Dist a
    Primitive   :: (Sampleable d) => d a -> Dist a
    Conditional :: (a -> Prob) -> Dist a -> Dist a

The Bind and Conditional constructors each wrap multiple things (values or functions); which of them corresponds to the value of type a?

The answer goes back to the idea demonstrated in the previous post: a GADT just does a registering job, and all the functions/operations registered in it are evaluated lazily (again, do not confuse this with Haskell's built-in lazy evaluation mechanism). We can view the GADT as a thunk and decide later, in the typeclass instance, what (>>=) should do. More specifically, see how a historical monad-bayes commit handles this problem:

instance Functor Dist where
    fmap  = liftM

instance Applicative Dist where
    pure  = return
    (<*>) = liftM2 ($)

instance Monad Dist where
    return = Return
    (>>=)  = Bind

The implementations of these three typeclasses jointly define all the necessary elements of a Monad. The trick is the definition of (>>=):

(>>=) = Bind

It is just another registering operation. That is to say, it doesn't matter what kind of data is in the Monad; we just assume that it is there and register the upcoming function in the data type. Again, nothing really happens until some trigger function is called.
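The registering behavior can be sketched with a cut-down version of Dist. The following is only an illustration under simplifying assumptions: Primitive holds a plain list of outcomes instead of a Sampleable value, Conditional is left out, and describe and enumerate are hypothetical helper names.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, liftM)

-- Simplified stand-in for Dist (assumption: Primitive is just a list).
data Dist a where
  Return    :: a -> Dist a
  Bind      :: Dist b -> (b -> Dist a) -> Dist a
  Primitive :: [a] -> Dist a

instance Functor Dist where
  fmap = liftM

instance Applicative Dist where
  pure  = Return
  (<*>) = ap

instance Monad Dist where
  (>>=) = Bind            -- only records the step, runs nothing

-- Inspect the outermost constructor without running anything.
describe :: Dist a -> String
describe (Return _)    = "Return"
describe (Bind _ _)    = "Bind"
describe (Primitive _) = "Primitive"

-- A "trigger" function: only here are the recorded steps executed.
enumerate :: Dist a -> [a]
enumerate (Return a)     = [a]
enumerate (Bind d k)     = concatMap (enumerate . k) (enumerate d)
enumerate (Primitive xs) = xs

program :: Dist Int
program = Primitive [1, 2] >>= \x -> Return (x * 10)

-- The continuation would crash if called, but building the value never calls it.
neverRun :: Dist Int
neverRun = Primitive [1, 2] >>= \_ -> error "never called while building"
```

In this sketch describe program and describe neverRun both give "Bind", while enumerate program gives [10,20]: the continuation is touched only when the trigger function walks the structure.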

Something more interesting is the fmap function:

fmap = liftM

This function is defined as liftM from Control.Monad, and liftM is itself defined in terms of (>>=). So the behavior of fmap now differs from what we normally see. It is no longer a function that fetches the value a from a Functor, applies f to it, and wraps the result back up. It is again just a registering operation: we assume that a value a is stored in the monad and register a function that, given that value, computes f a and wraps it with return. Since (>>=) is Bind and return is Return, fmap f d simply builds Bind d (Return . f).
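To make the shape produced by fmap visible, here is a small sketch using the same kind of simplified Dist (Primitive as a plain list, no Conditional; both are assumptions for illustration). The shape function is a hypothetical helper that prints the constructor layout and, when the left side of a Bind is a Return, peeks one step into the continuation.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, liftM)

-- Simplified stand-in for Dist (assumption: Primitive is just a list).
data Dist a where
  Return    :: a -> Dist a
  Bind      :: Dist b -> (b -> Dist a) -> Dist a
  Primitive :: [a] -> Dist a

instance Functor Dist where
  fmap = liftM            -- liftM f m = m >>= \x -> return (f x)

instance Applicative Dist where
  pure  = Return
  (<*>) = ap

instance Monad Dist where
  (>>=) = Bind

-- Describe the constructor layout of a Dist value.
shape :: Dist a -> String
shape (Return _)          = "Return"
shape (Primitive _)       = "Primitive"
shape (Bind (Return x) k) = "Bind Return (-> " ++ shape (k x) ++ ")"
shape (Bind d _)          = "Bind (" ++ shape d ++ ") (-> ?)"

mappedReturn :: Dist Int
mappedReturn = fmap (+ 1) (Return 1)

mappedPrimitive :: Dist Int
mappedPrimitive = fmap (+ 1) (Primitive [1, 2, 3])
```

Under these assumptions shape mappedReturn is "Bind Return (-> Return)" and shape mappedPrimitive is "Bind (Primitive) (-> ?)": fmap has added a Bind node whose continuation ends in Return, and the list inside Primitive has not been touched.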

Let's demonstrate all of the above ideas with a simple example:

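-- Excerpt: Sampleable and pick are defined elsewhere and are not shown here.
-- (,1) needs TupleSections; the Prob deriving clause needs GeneralizedNewtypeDeriving.
-- second comes from Control.Arrow, randomR from System.Random.
newtype Explicit a = Explicit {toList :: [(a,Prob)]}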
                deriving(Show)

normalize :: [(a,Prob)] -> [(a,Prob)]
normalize xs = map (second (/ norm)) xs where
    norm = sum $ map snd xs

instance Sampleable Explicit where
    sample g (Explicit xs) =
        pick xs $ fst $ randomR (0.0,1.0) g

class DiscreteDist d where
    categorical :: [(a,Prob)] -> d a


newtype Prob = Prob {toDouble :: Double}
    deriving (Show, Eq, Ord, Num, Fractional, Real, RealFrac, Floating, Random, Ext.Distribution Ext.StdUniform)

instance DiscreteDist Explicit where
    categorical = Explicit . normalize

instance DiscreteDist Dist where
    categorical = Primitive . (categorical :: [(a,Prob)] -> Explicit a)

type Samples a = [(a,Prob)]

resample :: Samples a -> Dist (Samples a)
resample xs = sequence $ replicate n $ fmap (,1) $ categorical xs where
    n = length xs

In the resample function, categorical xs returns a value of type Dist a. fmap (,1) $ categorical xs then turns it into a value of type Dist (a, Prob), using the fmap function discussed above. Note that what categorical xs really gives us looks like this:

Primitive (Explicit {toList = normalize xs :: Samples a})

Looking at this value directly, it is not clear how fmap would work on it. However, once we remember that fmap here is just registering, it becomes clear: when a trigger function such as sample is called, the Primitive value above will certainly produce a value of type a, and all the links are connected!
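The chain from Primitive through fmap to a trigger function can be sketched end to end. This is a toy version under several assumptions, not the monad-bayes implementation: Double stands in for Prob, sampleWith takes a uniform draw in [0,1) directly instead of a random generator, and sampleWith, pick, and run are hypothetical definitions written for this illustration.

```haskell
{-# LANGUAGE GADTs #-}
import Control.Monad (ap, liftM)

-- Assumption: sampling consumes one uniform draw u in [0,1).
class Sampleable d where
  sampleWith :: Double -> d a -> a

newtype Explicit a = Explicit { toList :: [(a, Double)] }

-- Walk the weighted list until the draw falls inside an entry.
pick :: [(a, Double)] -> Double -> a
pick [] _ = error "empty distribution"
pick ((x, p) : rest) u
  | u < p || null rest = x
  | otherwise          = pick rest (u - p)

instance Sampleable Explicit where
  sampleWith u (Explicit xs) = pick xs u

data Dist a where
  Return    :: a -> Dist a
  Bind      :: Dist b -> (b -> Dist a) -> Dist a
  Primitive :: Sampleable d => d a -> Dist a

instance Functor Dist where
  fmap = liftM
instance Applicative Dist where
  pure  = Return
  (<*>) = ap
instance Monad Dist where
  (>>=) = Bind

-- Trigger function: threads a supply of uniform draws through the graph.
run :: [Double] -> Dist a -> (a, [Double])
run us (Return a)         = (a, us)
run us (Bind d k)         = let (b, us') = run us d in run us' (k b)
run (u : us) (Primitive d) = (sampleWith u d, us)
run [] (Primitive _)      = error "ran out of uniform draws"

coin :: Dist String
coin = Primitive (Explicit [("heads", 0.5), ("tails", 0.5)])

-- Same idea as fmap (,1) in resample, written without TupleSections.
weighted :: Dist (String, Double)
weighted = fmap (\x -> (x, 1.0)) coin
```

With the draw 0.7, fst (run [0.7] weighted) is ("tails",1.0): only when run reaches the Primitive node is a value of type a produced, and only then is the function registered by fmap applied to it.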

This is the beauty of GADT!

CC BY-SA 4.0 Septimia Zenobia. Last modified: May 28, 2024.