It may look like a very old topic (circa 2015), but I did spend some time on it, and it really helped me better understand GADTs and their advantages in handling transformations, which we encounter all the time in many tasks, including Bayesian inference.
The key question is: why do we need GADTs when we already have ADTs, and what benefits do they bring? It took me some time to fully understand this, and the example of GADTs in Bayesian inference really made it clear!
Before we get started, I should mention the materials this post is based on. This is important because the idea seems to have been discussed and put into practice around 2015 (8 years ago!), and the same people no longer use the technique (perhaps GADTs have some drawbacks here that I haven't figured out). So keep in mind that this may be an old, outdated approach!
The first is the paper Practical Probabilistic Programming with Monads by Adam Ścibior, Zoubin Ghahramani and Andrew D. Gordon, published in 2015. See Section 3.3.
The second is the GitHub repository monad-bayes, by the authors of the 2015 paper. Its GADT implementation was replaced by monad transformers in 2016, so you need to check out commit 5c1f926b6f1b8fdd382c74c0922b21817897faf2 to retrieve the original code. (You may also want to read the PR to see why GADTs were replaced with monad transformers.)
The third is a talk by Tikhon Jelvis that closely follows the original idea of the 2015 paper.
The basic idea is to define a probability monad as follows:
```haskell
data Dist a where
  Return      :: a -> Dist a
  Bind        :: Dist b -> (b -> Dist a) -> Dist a
  Primitive   :: (Sampleable d) => d a -> Dist a
  Conditional :: (a -> Prob) -> Dist a -> Dist a
```
Here, using a GADT, four constructors are defined.
The first one is `Return :: a -> Dist a`: given a value of type `a`, it wraps it in `Return`. This is exactly what we would see in an ordinary ADT definition such as `data Dist a = Dist a`.
The second one is `Bind :: Dist b -> (b -> Dist a) -> Dist a`, and this is where the trick really lies. `Bind` registers a value of type `Dist b` and a function of type `b -> Dist a`, and wraps them into the type `Dist a`. Note that `b` does not appear in the result type: a value of type `Dist a` can be another value, of some type `Dist b`, combined with a function of type `b -> Dist a`. Together with `Return :: a -> Dist a`, this builds a tree structure representing a probabilistic graphical model.
The third one is `Primitive :: (Sampleable d) => d a -> Dist a`. It wraps a value of type `d a`, for any `d` that has a `Sampleable` instance, into `Dist a`. This is how primitive distributions such as categorical, multinomial, Gaussian, etc. enter the model.
The final one is `Conditional :: (a -> Prob) -> Dist a -> Dist a`. Just like `Bind` and `Primitive`, it wraps things into `Dist a`: here, a function of type `a -> Prob` (which scores each value, acting as a likelihood) and a value of type `Dist a`.
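To make the four constructors concrete, here is a minimal compilable sketch. The post does not show `Prob` or the `Sampleable` class, so both are assumptions: `Prob` is taken to be `Double`, and `sample` takes a placeholder generator type `Gen` rather than a real `RandomGen`. The `Monad` instance is also my own addition; it shows that `>>=` does nothing except build a `Bind` node, so a `do` block becomes a tree of registered values and functions. `program` is a hypothetical example.

```haskell
{-# LANGUAGE GADTs #-}

import Control.Monad (ap, liftM)
import Data.Word (Word32)

-- Assumption: probabilities and likelihood weights are plain Doubles.
type Prob = Double

-- Assumption: a placeholder standing in for a random generator such as StdGen.
newtype Gen = Gen Word32

-- Assumption: the shape of the Sampleable class (not shown in the post).
class Sampleable d where
  sample :: Gen -> d a -> a

data Dist a where
  Return      :: a -> Dist a
  Bind        :: Dist b -> (b -> Dist a) -> Dist a
  Primitive   :: (Sampleable d) => d a -> Dist a
  Conditional :: (a -> Prob) -> Dist a -> Dist a

-- The monad instance only records the computation: >>= builds a Bind node.
instance Functor Dist where
  fmap = liftM

instance Applicative Dist where
  pure  = Return
  (<*>) = ap

instance Monad Dist where
  (>>=) = Bind

-- Hypothetical program: a do block, then an observation favouring sums above 2.
program :: Dist Int
program = Conditional (\s -> if s > 2 then 1 else 0.1) $ do
  x <- Return 1
  y <- Return (x + 1)
  return (x + y)
```

Nothing in `program` has been run: it is a `Conditional` node whose inner value is a `Bind` whose first component is `Return 1`, with the rest of the `do` block stored as an unevaluated function.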
As you can see, this use of GADTs is all about registering. Values and functions (of certain types) are registered into a single type, `Dist a`, and none of them is evaluated until it is needed (in a sense, this is a hand-rolled implementation of lazy evaluation; please do not confuse it with Haskell's built-in lazy evaluation). This way, you can define a probabilistic graphical model very conveniently.
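The claim that nothing is evaluated until needed can be checked directly. In the sketch below, `Prob`, `Gen` and the `Sampleable` class are placeholder assumptions, and `describe` is a hypothetical helper that walks a `Dist` and reports its shape without sampling anything. Functions stored in `Bind` and `Conditional` are opaque, so it can only say that one was registered.

```haskell
{-# LANGUAGE GADTs #-}

import Data.Word (Word32)

-- Assumptions: placeholder types; the post does not show them.
type Prob = Double
newtype Gen = Gen Word32

class Sampleable d where
  sample :: Gen -> d a -> a

data Dist a where
  Return      :: a -> Dist a
  Bind        :: Dist b -> (b -> Dist a) -> Dist a
  Primitive   :: (Sampleable d) => d a -> Dist a
  Conditional :: (a -> Prob) -> Dist a -> Dist a

-- Hypothetical helper: show the registered structure without running it.
-- Stored functions are opaque, so we can only note that one was registered.
describe :: Dist a -> String
describe (Return _)        = "Return"
describe (Primitive _)     = "Primitive"
describe (Bind d _)        = "Bind (" ++ describe d ++ ") <function>"
describe (Conditional _ d) = "Conditional <likelihood> (" ++ describe d ++ ")"
```

For example, `describe (Conditional (const 1) (Bind (Return 'x') (\_ -> error "never run")))` evaluates to `"Conditional <likelihood> (Bind (Return) <function>)"`; the `error` inside the continuation is never triggered, because describing the tree never calls it.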
So let's look at how and when it is evaluated:
```haskell
-- g is a random number generator; split comes from System.Random
instance Sampleable Dist where
  sample g (Return x) = x
  sample g (Primitive d) = sample g d
  sample g (Bind d f) = sample g1 $ f $ sample g2 d where
    (g1, g2) = split g
  sample g (Conditional c d) = error "Attempted to sample from a conditional distribution."
```
Here, the parameter `g` is a random number generator, such as `StdGen` from `System.Random`. Let's go through each case of `sample`:
`sample g (Return x) = x`: the `Return` case simply gives back the value wrapped in `Return`. This is the most common kind of case we encounter with ordinary ADT definitions.
`sample g (Primitive d) = sample g d`: this triggers the `sample` function of `d`, which is available because the constructor carries the `Sampleable d` constraint. (The value of type `d a` was registered in `Dist a` through the GADT.)
`sample g (Bind d f) = sample g1 $ f $ sample g2 d where (g1, g2) = split g` is the most interesting one. It first triggers the `sample` function of `d` (with generator `g2`), obtaining a value `x` of type `b`. It then applies `f` to get `f x`, a value of type `Dist a`. Finally, it calls `sample` on that `Dist a` (with generator `g1`; this is a recursive call) and returns the resulting sample. Splitting `g` gives the two steps separate random streams. This is the behavior we see in essentially every probabilistic programming language (PPL); indeed, it is the essence of a PPL. Again, both `d` and `f` were registered in `Dist a` through the GADT.
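The three cases above can be exercised with a small runnable sketch. To avoid depending on the `random` package, it swaps `StdGen` and `split` for a hypothetical toy generator: a 32-bit linear congruential generator `Gen` with a `splitSeed` function that derives two new seeds. It is deterministic and not statistically sound. `Prob`, the exact `Sampleable` class and the biased-coin primitive `Bernoulli` are assumptions too. `Bernoulli` is itself a tiny GADT whose constructor fixes the result type to `Bool`. The `Sampleable Dist` instance has the same case analysis as the one in the post.

```haskell
{-# LANGUAGE GADTs #-}

import Data.Word (Word32)

-- Assumption: probabilities are plain Doubles.
type Prob = Double

-- Hypothetical toy generator standing in for StdGen: a 32-bit LCG state.
newtype Gen = Gen Word32

-- One LCG step (Numerical Recipes constants); Word32 arithmetic wraps mod 2^32.
next :: Gen -> (Word32, Gen)
next (Gen s) = let s' = 1664525 * s + 1013904223 in (s', Gen s')

-- A Double in [0, 1) derived from the next LCG output.
uniform :: Gen -> Double
uniform g = fromIntegral (fst (next g)) / 4294967296

-- Hypothetical replacement for split: derive two seeds (not statistically sound).
splitSeed :: Gen -> (Gen, Gen)
splitSeed (Gen s) = (snd (next (Gen s)), snd (next (Gen (s * 2654435761 + 1))))

-- Assumption: the shape of the Sampleable class.
class Sampleable d where
  sample :: Gen -> d a -> a

data Dist a where
  Return      :: a -> Dist a
  Bind        :: Dist b -> (b -> Dist a) -> Dist a
  Primitive   :: (Sampleable d) => d a -> Dist a
  Conditional :: (a -> Prob) -> Dist a -> Dist a

-- Hypothetical primitive: a biased coin. The GADT return type fixes a to Bool.
data Bernoulli a where
  Bernoulli :: Double -> Bernoulli Bool

instance Sampleable Bernoulli where
  sample g (Bernoulli p) = uniform g < p

-- Same case analysis as the instance in the post.
instance Sampleable Dist where
  sample _ (Return x)        = x
  sample g (Primitive d)     = sample g d
  sample g (Bind d f)        = sample g1 $ f $ sample g2 d
    where (g1, g2) = splitSeed g
  sample _ (Conditional _ _) = error "Attempted to sample from a conditional distribution."

-- Hypothetical program: flip a fair coin and turn it into 1 or 0.
coinToNumber :: Dist Int
coinToNumber = Bind (Primitive (Bernoulli 0.5)) (\b -> Return (if b then 1 else 0))
```

With this setup, `sample (Gen 42) (Return 7)` gives `7`, and sampling `Bind (Primitive (Bernoulli 1.0)) f` always passes `True` to `f`, since a uniform draw in [0, 1) is always below `1.0`.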
`sample g (Conditional c d) = error "Attempted to sample from a conditional distribution."` simply refuses by raising an error: a conditioned distribution cannot be sampled directly. During inference, a `Dist a` that contains `Conditional` blocks should first be transformed into one without them, using algorithms such as MCMC, variational inference and so on. This kind of transform-before-running behavior is easy to realize with GADTs in Haskell (again, thanks to the deferred evaluation described above)!
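As a concrete example of such a transformation, the sketch below uses likelihood weighting (importance sampling with the prior as the proposal). This technique is my choice for illustration, much simpler than MCMC or variational inference, and it is not the code from the paper or the repository. `weighted` runs the program forward and multiplies in the likelihood of every `Conditional` it passes, so it handles exactly the case that `sample` refuses; its result, a (value, weight) pair, has no conditioning left in it. `estimate` then computes a self-normalized estimate of an expectation. `Prob`, `Gen`, `splitSeed`, `Sampleable` and `Bernoulli` are hypothetical toy definitions.

```haskell
{-# LANGUAGE GADTs #-}

import Data.Word (Word32)

-- Assumption: probabilities and likelihood weights are plain Doubles.
type Prob = Double

-- Hypothetical toy generator (32-bit LCG) standing in for StdGen.
newtype Gen = Gen Word32

next :: Gen -> (Word32, Gen)
next (Gen s) = let s' = 1664525 * s + 1013904223 in (s', Gen s')

uniform :: Gen -> Double
uniform g = fromIntegral (fst (next g)) / 4294967296

-- Hypothetical replacement for split (not statistically sound).
splitSeed :: Gen -> (Gen, Gen)
splitSeed (Gen s) = (snd (next (Gen s)), snd (next (Gen (s * 2654435761 + 1))))

class Sampleable d where
  sample :: Gen -> d a -> a

data Dist a where
  Return      :: a -> Dist a
  Bind        :: Dist b -> (b -> Dist a) -> Dist a
  Primitive   :: (Sampleable d) => d a -> Dist a
  Conditional :: (a -> Prob) -> Dist a -> Dist a

-- Hypothetical biased-coin primitive.
data Bernoulli a where
  Bernoulli :: Double -> Bernoulli Bool

instance Sampleable Bernoulli where
  sample g (Bernoulli p) = uniform g < p

-- Likelihood weighting: sample from the prior and multiply in the likelihood
-- of every Conditional met on the way.
weighted :: Gen -> Dist a -> (a, Prob)
weighted _ (Return x)        = (x, 1)
weighted g (Primitive d)     = (sample g d, 1)
weighted g (Bind d f)        =
  let (g1, g2) = splitSeed g
  in case weighted g2 d of
       (x, w1) -> case weighted g1 (f x) of
         (y, w2) -> (y, w1 * w2)
weighted g (Conditional c d) = case weighted g d of
  (x, w) -> (x, w * c x)

-- Self-normalized importance estimate of E[h x] under the conditioned
-- program, from n weighted draws, each with its own generator.
estimate :: Int -> (a -> Double) -> Gen -> Dist a -> Double
estimate n h g0 d = sum [w * h x | (x, w) <- draws] / sum (map snd draws)
  where
    draws = take n (go g0)
    go g  = let (g1, g2) = splitSeed g in weighted g1 d : go g2
```

For instance, a fair coin observed with likelihood 0.3 for `True` and 0.7 for `False` has posterior probability 0.5 × 0.3 / (0.5 × 0.3 + 0.5 × 0.7) = 0.3 for `True`, which is the value `estimate` should approach as `n` grows, given a decent generator.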
Compared with ordinary ADTs, what GADTs provide is a flexible way to wrap more kinds of data into your data type: each constructor may mention type variables that do not appear in the result type (the `b` in `Bind`, the `d` in `Primitive`) and may carry class constraints (`Sampleable d`). You can register not only a value or a function of some type, but also a combination of values and functions, whose behavior is defined later in an interpreting function such as `sample`. This way, a whole computation, however it is composed, can be expressed as a value of a single data type. That is amazing!
To see why a GADT is needed instead of an ordinary ADT, consider rewriting the example above as a normal ADT. The GADT definition
```haskell
data Dist a where
  Return      :: a -> Dist a
  Bind        :: Dist b -> (b -> Dist a) -> Dist a
  Primitive   :: (Sampleable d) => d a -> Dist a
  Conditional :: (a -> Prob) -> Dist a -> Dist a
```
would be rewritten as follows:
```haskell
data Dist a = Return a
            | Bind (Dist b) (b -> Dist a)
            | Primitive (d a)
            | Conditional (a -> Prob) (Dist a)
```
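This code does not compile. Because `b` and `d` appear in the constructors but are not parameters of `Dist a`, nothing binds them, and GHC reports:

```
error: Not in scope: type variable ‘b’
error: Not in scope: type variable ‘d’
```

In the GADT version, by contrast, each constructor's signature introduces its own type variables.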
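What the GADT declaration supplies here is existential quantification: `Bind` and `Primitive` can hold values whose types (`b`, `d`) are chosen when the value is built and then hidden behind `Dist a`. Plain Haskell 98 `data` declarations cannot express this. For comparison, here is a sketch of the ADT above repaired with GHC's `ExistentialQuantification` extension, the other common way to get this effect. `Prob`, `Gen`, `Sampleable` and the helper `runPure` are assumptions for illustration.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

import Data.Word (Word32)

-- Assumptions: placeholder types, not shown in the post.
type Prob = Double
newtype Gen = Gen Word32

class Sampleable d where
  sample :: Gen -> d a -> a

-- The ADT with its two problems fixed: `forall b.` and `forall d.` bind the
-- variables GHC reported as "not in scope". Both are existential: picked
-- when a value is built, then hidden inside Dist a.
data Dist a
  = Return a
  | forall b. Bind (Dist b) (b -> Dist a)
  | forall d. Sampleable d => Primitive (d a)
  | Conditional (a -> Prob) (Dist a)

-- Hypothetical interpreter for programs with no randomness or conditioning.
runPure :: Dist a -> Maybe a
runPure (Return x)        = Just x
runPure (Bind d f)        = runPure d >>= runPure . f
runPure (Primitive _)     = Nothing
runPure (Conditional _ _) = Nothing
```

Here `runPure (Bind (Return 20) (\x -> Return (x + 22)))` gives `Just 42`. The GADT syntax in the post gets the same existential behavior without writing `forall` explicitly, since each constructor signature quantifies its own variables.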