This time, I would like to talk about recursion schemes: the idea, the benefits, and a real-world example.
The materials I follow include the famous series of posts [1], the data-fix package [2], and a real-world type-inference project [3].
Briefly speaking, recursion-schemes is another style of expressing recursive data types in functional programming languages (such as Haskell). It takes advantage of the fixpoint of a functor and provides one more level of abstraction for traversing tree-like data, compared with the "naive" definition of a recursive data type.
I will not go into every detail, since the famous posts [1] have already explained it clearly; a curious reader should refer to them. Let's instead start with a very brief summary of the idea, then jump directly into its usage in a real-world language project.
The definition of the example AST (from the post [1]) is:
data Lit
  = StrLit String
  | IntLit Int
  | Ident String
  deriving (Show, Eq)
data Expr
  = Index Expr Expr
  | Call Expr [Expr]
  | Unary String Expr
  | Binary Expr String Expr
  | Paren Expr
  | Literal Lit
  deriving (Show, Eq)
This is the "naive" definition and will be familiar to any Haskeller. The trick is that the constructors of Expr can themselves take Expr as a parameter, which is what makes the datatype recursive.
The recursion-schemes style, however, defines the above Expr in this way:
-- use TemplateHaskell only for better support of `show` function
{-# LANGUAGE TemplateHaskell #-}
-- optional, only to demonstrate ExprF could be a Functor
{-# LANGUAGE DeriveFunctor #-}
-- for better demonstration, we will use `data-fix`
import Data.Fix
-- better support of `show` function
import Text.Show.Deriving
data Lit
  = StrLit String
  | IntLit Int
  | Ident String
  deriving (Show, Eq)
data ExprF a
  = Index a a
  | Call a [a]
  | Unary String a
  | Binary a String a
  | Paren a
  | Literal Lit
  deriving (Show, Eq, Functor)
-- better support of `show` function
$(deriveShow1 ''ExprF)
type Expr = Fix ExprF
-- -- the Fix datatype is just
-- newtype Fix f = Fix { unFix :: f (Fix f) }
d1 = Fix (Literal (IntLit 10)) :: Expr
d2 = Fix $ Paren $ Fix (Literal (IntLit 10)) :: Expr
As you can see, Expr is defined as the fixpoint of ExprF, while ExprF just represents one layer of our AST in a non-recursive way: looking at ExprF a alone, it is hard to see how it could build a tree-like structure of arbitrary depth.
However, using the fixpoint trick, this becomes possible.
Take a look at the definition of Fix (as imported from Data.Fix in the above example):
newtype Fix f = Fix { unFix :: f (Fix f) }
The trick is to move the recursion into Fix itself: f describes a single layer, and every recursive position of that layer holds another Fix f value.
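To see how Fix ties the knot on its own, here is a minimal sketch with a much smaller functor than ExprF. NatF, two, toInt and fromInt are names made up for this illustration, and Fix is redefined locally (mirroring the newtype above) so that the snippet only depends on base.

```haskell
-- Local copy of Fix, so the sketch does not need the data-fix package.
newtype Fix f = Fix { unFix :: f (Fix f) }

-- Hypothetical example functor: one non-recursive layer of a natural number.
data NatF a = Zero | Succ a

type Nat = Fix NatF

-- Every Fix wraps exactly one layer; the hole of that layer holds the next Fix.
two :: Nat
two = Fix (Succ (Fix (Succ (Fix Zero))))

-- Explicit recursion, peeling off one layer at a time with unFix.
toInt :: Nat -> Int
toInt n = case unFix n of
  Zero   -> 0
  Succ m -> 1 + toInt m

-- The inverse direction: build the nested layers from an Int.
fromInt :: Int -> Nat
fromInt k
  | k <= 0    = Fix Zero
  | otherwise = Fix (Succ (fromInt (k - 1)))
```

NatF itself never mentions NatF on its right-hand side; all of the recursion comes from Fix.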
To better understand this, consider a marginal case (a leaf, that is, a constructor with no recursive position):
Fix (Literal (IntLit 10)) :: Expr
Will this pass the type-check?
Yes! Because Literal (IntLit 10) can be type-checked as ExprF a, where a can be any type. More specifically:
l1 = Literal (IntLit 10) :: ExprF a
l2 = Literal (IntLit 10) :: ExprF ()
l3 = Literal (IntLit 10) :: ExprF (ExprF a)
l4 = Literal (IntLit 10) :: ExprF (Fix ExprF)
can all be type-checked.
With this, Fix l4 :: Fix ExprF type-checks.
Then, immediately, you can see that
d2 = Fix $ Paren $ Fix (Literal (IntLit 10)) :: Expr
also type-checks, since applying Fix to a value of type f (Fix f) yields a Fix f.
So, one reason why recursion-schemes works is that a marginal case of ExprF a can be type-checked with any type a; or, to put it another way, a marginal case such as Literal (IntLit 10) is polymorphic in the type variable a.
Knowing this is important, since the traversal mechanisms provided by recursion-schemes take a function that handles one layer at a time: the marginal cases produce the initial results, and the recursion through all other cases is handled automatically, without any extra traversal code.
Consider a folding case from the post[1]:
import Text.PrettyPrint (Doc)
import qualified Text.PrettyPrint as P
ten, add, call :: Expr
ten = Fix (Literal (IntLit 10))
add = Fix (Literal (Ident "add" ))
call = Fix (Call add [ten, ten]) --add(10, 10)
type Algebra f a = f a -> a
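prettyPrint :: Algebra ExprF Doc
-- string literals are shown in double quotes (case added so the match is exhaustive)
prettyPrint (Literal (StrLit s)) = P.doubleQuotes (P.text s)
prettyPrint (Literal (IntLit i)) = P.int i
prettyPrint (Literal (Ident s)) = P.text s
prettyPrint (Call f as) =
  -- f(a,b...)
  -- P.comma is used instead of the string "," so OverloadedStrings is not needed
  f <> P.parens (mconcat (P.punctuate P.comma as))
prettyPrint (Index it idx) =
  -- a[b]
  it <> P.brackets idx
prettyPrint (Unary op it) =
  -- op x
  P.text op <> it
prettyPrint (Binary l op r) =
  -- lhs op rhs
  l <> P.text op <> r
prettyPrint (Paren exp) =
  -- (exp)
  P.parens exp
v = foldFix prettyPrint call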
This will result in add(10,10).
The magic foldFix provided by data-fix has this type signature:
foldFix :: Functor f => (f a -> a) -> Fix f -> a
This function is named cata in the post, and the implementation looks like this:
import Control.Arrow
-- foldFix is also named cata, meaning catamorphism
foldFix :: (Functor f) => Algebra f a -> Fix f -> a
foldFix fn = unFix >>> fmap (foldFix fn) >>> fn
What this function does is:
first, use unFix to fetch the content f (Fix f);
then, apply fmap (foldFix fn) to it, which is just the recursive call. After this step, f (Fix f) has become f a;
finally, apply fn to the result, transforming f a into a.
Looks familiar? Yes, it works just like the foldr function!
As its name indicates, foldFix is the Fix version of the fold function.
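The resemblance to foldr can be made concrete with a small sketch. ListF, fromList, sumAlg and sumFix are hypothetical names introduced only for this comparison; Fix and foldFix are redefined locally, following the definitions shown above, so the snippet needs nothing beyond base.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copies of Fix and foldFix, following the definitions in the text.
newtype Fix f = Fix { unFix :: f (Fix f) }

foldFix :: Functor f => (f a -> a) -> Fix f -> a
foldFix fn = fn . fmap (foldFix fn) . unFix

-- Hypothetical functor: one layer of a list with elements of type e.
data ListF e a = Nil | Cons e a
  deriving Functor

-- Convert an ordinary list into its Fix-based representation.
fromList :: [e] -> Fix (ListF e)
fromList = foldr (\e rest -> Fix (Cons e rest)) (Fix Nil)

-- The algebra plays the role of foldr's two arguments:
-- Nil corresponds to the initial value, Cons to the combining function.
sumAlg :: ListF Int Int -> Int
sumAlg Nil          = 0
sumAlg (Cons e acc) = e + acc

sumFix :: [Int] -> Int
sumFix = foldFix sumAlg . fromList
```

Under these assumptions, sumFix computes the same result as foldr (+) 0: in the Cons case the second field already holds the folded tail, exactly like the accumulator argument of foldr.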
In fact, we can write any recursive function we would write over a "naive" AST definition in this style: foldr, foldl, map, mapM, whatever!
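As a rough sketch of two such functions, the snippet below shows a map-like bottom-up rewrite and a mapM-like monadic fold. ArithF, stripParens, countParens, evalM and cataM are hypothetical names for this illustration; they are not taken from the post or from any library, and Fix and foldFix are again redefined locally so the code only needs base.

```haskell
{-# LANGUAGE DeriveTraversable #-}

newtype Fix f = Fix { unFix :: f (Fix f) }

foldFix :: Functor f => (f a -> a) -> Fix f -> a
foldFix fn = fn . fmap (foldFix fn) . unFix

-- Hypothetical local helper: a fold whose algebra runs in a monad.
cataM :: (Monad m, Traversable f) => (f a -> m a) -> Fix f -> m a
cataM fn = go
  where
    go (Fix layer) = traverse go layer >>= fn

-- A small arithmetic functor, made up for this example.
data ArithF a
  = Lit Int
  | Add a a
  | Div a a
  | Paren a
  deriving (Functor, Foldable, Traversable)

type Arith = Fix ArithF

-- Map-like rewrite: the result type of the fold is the tree itself.
stripParens :: Arith -> Arith
stripParens = foldFix alg
  where
    alg (Paren e) = e
    alg layer     = Fix layer

-- Ordinary fold: count the Paren nodes.
countParens :: Arith -> Int
countParens = foldFix alg
  where
    alg (Paren n) = 1 + n
    alg layer     = sum layer -- add up the counts of all children

-- mapM-like fold: evaluation fails on division by zero.
evalM :: ArithF Int -> Maybe Int
evalM (Lit n)   = Just n
evalM (Add a b) = Just (a + b)
evalM (Div _ 0) = Nothing
evalM (Div a b) = Just (a `div` b)
evalM (Paren a) = Just a

-- (4 + 6) / 2
sample :: Arith
sample =
  Fix (Div (Fix (Paren (Fix (Add (Fix (Lit 4)) (Fix (Lit 6)))))) (Fix (Lit 2)))
```

In each case only the per-layer function changes; the traversal itself stays in foldFix or cataM.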
Now, let's go back to the example.
Consider the implementation of prettyPrint (Index it idx):
prettyPrint (Index it idx) = it <> P.brackets idx
We know that prettyPrint has the type signature prettyPrint :: ExprF Doc -> Doc, so it is easy to see why <> sits between the two objects it and P.brackets idx, since (<>) :: Doc -> Doc -> Doc. But wait, do it and idx really have the type Doc?
The answer is yes, in the same way that the accumulator argument of the function passed to foldr already holds the folded result of the rest of the list.
More specifically, consider the two marginal cases:
prettyPrint (Literal (IntLit i)) = P.int i
prettyPrint (Literal (Ident s)) = P.text s
By looking at this, you can easily see that the Expr AST is reduced into Doc starting from the marginal cases (leaves) and propagating upward through the traversal, so that by the time it reaches the Index a a node, a is already Doc.
The key to understanding this smoothly, again, is to remember that Literal (IntLit i) can be type-checked as ExprF a for any a (here ExprF Doc), so a leaf needs no already-folded children, just as in a "naive" AST definition.
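One way to convince yourself that the children are already reduced is to call the algebra directly on a single layer. The sketch below uses String instead of Doc so that it only depends on base; render, oneLayer and whole are hypothetical names, and Fix and foldFix are redefined locally.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.List (intercalate)

newtype Fix f = Fix { unFix :: f (Fix f) }

foldFix :: Functor f => (f a -> a) -> Fix f -> a
foldFix fn = fn . fmap (foldFix fn) . unFix

data Lit
  = StrLit String
  | IntLit Int
  | Ident String

data ExprF a
  = Index a a
  | Call a [a]
  | Unary String a
  | Binary a String a
  | Paren a
  | Literal Lit
  deriving Functor

-- A String-based stand-in for prettyPrint; it is not recursive at all.
render :: ExprF String -> String
render (Literal (StrLit s)) = show s
render (Literal (IntLit i)) = show i
render (Literal (Ident s))  = s
render (Call f args)        = f ++ "(" ++ intercalate "," args ++ ")"
render (Index it idx)       = it ++ "[" ++ idx ++ "]"
render (Unary op it)        = op ++ it
render (Binary l op r)      = l ++ op ++ r
render (Paren e)            = "(" ++ e ++ ")"

-- The algebra applied to one layer whose children are already Strings.
oneLayer :: String
oneLayer = render (Index "xs" "0")

-- The same node as a full tree; foldFix reduces the leaves first.
whole :: String
whole = foldFix render (Fix (Index (Fix (Literal (Ident "xs"))) (Fix (Literal (IntLit 0)))))
```

Both oneLayer and whole evaluate to the same string, because foldFix has turned the two leaves into Strings before render ever sees the Index node.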
Now that we know a little about recursion-schemes, its definition and the basic idea, we move on to a real-world example: hindley-milner-type-check [3].
As its name suggests, hindley-milner-type-check is a library that provides HM-based type inference for an arbitrary language based on the lambda calculus. It defines its own language in the recursion-schemes style, and a custom language has to be mapped onto it in order to use its type-inference facility. Today, we mainly focus on its language definition and some folds over it.
The language is defined as follows:
import Data.Fix
import Data.Eq.Deriving
import Data.Ord.Deriving
import Text.Show.Deriving
-- | Term functor. The arguments are
--   loc for source code locations
--   v for variables
--   r for recursion
data TermF prim loc v r
    = Var loc v                       -- ^ Variables.
    | Prim loc prim                   -- ^ Primitives.
    | App loc r r                     -- ^ Applications.
    | Lam loc v r                     -- ^ Abstractions.
    | Let loc (Bind loc v r) r        -- ^ Let bindings.
    | LetRec loc [Bind loc v r] r     -- ^ Recursive let bindings
    | AssertType loc r (Type loc v)   -- ^ Assert type.
    | Case loc r [CaseAlt loc v r]    -- ^ case alternatives
    | Constr loc v                    -- ^ constructor with tag
    | Bottom loc                      -- ^ value of any type that means failed program.
    deriving (Show, Eq, Functor, Foldable, Traversable, Data)
-- | Case alternatives
data CaseAlt loc v a = CaseAlt
  { caseAlt'loc :: loc
  -- ^ source code location
  , caseAlt'tag :: v
  -- ^ tag of the constructor
  , caseAlt'args :: [(loc, v)]
  -- ^ arguments of the pattern matching
  , caseAlt'rhs :: a
  -- ^ right-hand side of the case-alternative
  }
  deriving (Show, Eq, Functor, Foldable, Traversable, Data)
-- | Local variable definition.
--
-- > let lhs = rhs in ...
data Bind loc var a = Bind
  { bind'loc :: loc -- ^ Source code location
  , bind'lhs :: var -- ^ Variable name
  , bind'rhs :: a   -- ^ Definition (right-hand side)
  } deriving (Show, Eq, Functor, Foldable, Traversable, Data)
$(deriveShow1 ''TermF)
$(deriveEq1 ''TermF)
$(deriveOrd1 ''TermF)
$(deriveShow1 ''Bind)
$(deriveEq1 ''Bind)
$(deriveOrd1 ''Bind)
$(deriveShow1 ''CaseAlt)
$(deriveEq1 ''CaseAlt)
$(deriveOrd1 ''CaseAlt)
Take a look at the TermF definition. Besides prim (the type of primitives), it has three type variables: loc for source code locations, v for variables, and r for recursion.
loc is not important in our study, so we just ignore it. v is the type of the variable names stored in the AST, and it appears in the marginal cases we will discuss later. r is the recursion variable, and during folding it is replaced by the result type of the fold. Note that CaseAlt and Bind are also instantiated with r in their last position inside TermF; this shows how a recursive data type that goes through helper types (mutually recursive definitions in the "naive" style) can be expressed using recursion-schemes.
The definition of Term is then based on TermF using the fixpoint:
-- | The type of terms.
newtype Term prim loc v = Term { unTerm :: Fix (TermF prim loc v) }
  deriving (Show, Eq, Data)
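To illustrate how the recursion variable reaches through a helper type, here is a reduced sketch. This Bind and TermF are simplified stand-ins written for this post (no loc or prim parameters, only four constructors, and field names bindLhs and bindRhs that I made up); they are not the library's definitions.

```haskell
{-# LANGUAGE DeriveTraversable #-}

-- Simplified stand-in for the library's Bind: the last parameter is the hole.
data Bind v a = Bind { bindLhs :: v, bindRhs :: a }
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- Simplified stand-in for TermF: r also appears inside Bind and inside a list of Bind.
data TermF v r
  = Var v
  | App r r
  | Let (Bind v r) r
  | LetRec [Bind v r] r
  deriving (Show, Eq, Functor, Foldable, Traversable)

-- One layer whose recursive positions have already been replaced by Ints.
letLayer :: TermF String Int
letLayer = Let (Bind "x" 1) 2

letRecLayer :: TermF String Int
letRecLayer = LetRec [Bind "f" 10, Bind "g" 20] 30
```

With the derived instances, fmap and sum visit every r, including those stored inside Bind values, which is what allows foldFix to recurse through the helper types without extra code.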
Now let's take a look at a simple folding example using recursion-schemes: getting the free variables of an AST.
-- | Get free variables of the term.
freeVars :: Ord v => Term prim loc v -> Set v
freeVars = foldFix go . unTerm
  where
    go = \case
      Var _ v -> S.singleton v
      Prim _ _ -> mempty
      App _ a b -> mappend a b
      Lam _ v a -> S.delete v a
      Let _ bind body ->
        let lhs = S.singleton $ bind'lhs bind
        in  mappend (bind'rhs bind)
                    (body `S.difference` lhs)
      LetRec _ binds body ->
        let lhs = S.fromList $ fmap bind'lhs binds
        in  (mappend (freeBinds binds) body)
              `S.difference` lhs
      AssertType _ a _ -> a
      Case _ e alts -> mappend e (foldMap freeVarAlts alts)
      Constr _ _ -> mempty
      Bottom _ -> mempty

    freeBinds = foldMap bind'rhs

    freeVarAlts CaseAlt{..} =
      caseAlt'rhs `S.difference`
        (S.fromList $ fmap snd caseAlt'args)
Again, let's first focus on the two marginal cases: Var loc v and Prim loc prim.
The Var loc v case is simple: since v IS a free variable, we have
Var _ v -> S.singleton v
The Prim loc prim case is even simpler, as it has no variables inside:
Prim _ _ -> mempty
Then, let's take a look at three non-marginal cases as examples (the other cases follow the same rule): App loc r r, Lam loc v r, and Let loc (Bind loc v r) r.
The App loc r r case is simple, since the free variables of an application of r1 to r2 are just the union of the free variables of r1 and r2. And remember that under foldFix, each r in App loc r r has already been reduced to the result type of the fold (i.e. Set v), so mappend is enough:
App _ a b -> mappend a b
The Lam loc v r case looks a little complex at first glance, since it has both v and r inside. But remember that r will already have been reduced to Set v under foldFix, so we only need a set operation on it:
Lam _ v a -> S.delete v a
We remove v from the set a since v is bound by the lambda.
The Let loc (Bind loc v r) r case is the most complex one, since it also refers to another datatype, Bind. Recall the definition of Bind:
data Bind loc var a = Bind
  { bind'loc :: loc -- ^ Source code location
  , bind'lhs :: var -- ^ Variable name
  , bind'rhs :: a   -- ^ Definition (right-hand side)
  } deriving (Show, Eq, Functor, Foldable, Traversable, Data)
Looks complex? We just need to apply the strategy we used to analyze TermF to Bind.
First, v in Bind loc v r is the variable name, as in the marginal cases.
Second, r is the recursion variable, which has already been reduced to Set v during folding.
So we use bind'lhs bind to fetch the bound variable and bind'rhs bind to fetch the Set of free variables of the right-hand side. The bound variable is removed from the free variables of the body only, because a non-recursive let does not bring it into scope in its own right-hand side. The resulting function looks like this:
Let _ bind body ->
  let
    lhs = S.singleton $ bind'lhs bind
  in
    mappend (bind'rhs bind) (body `S.difference` lhs)
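The four cases discussed above can be tried out in isolation with a reduced sketch. The TermF and Bind below are simplified stand-ins (no loc or prim, only Var, App, Lam and Let, with made-up field names), so this is not the library's code; Fix and foldFix are redefined locally, and Data.Set comes from the containers package.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import           Data.Set (Set)
import qualified Data.Set as S

newtype Fix f = Fix { unFix :: f (Fix f) }

foldFix :: Functor f => (f a -> a) -> Fix f -> a
foldFix fn = fn . fmap (foldFix fn) . unFix

-- Simplified stand-ins for the library's Bind and TermF.
data Bind v a = Bind { bindLhs :: v, bindRhs :: a }
  deriving Functor

data TermF v r
  = Var v
  | App r r
  | Lam v r
  | Let (Bind v r) r
  deriving Functor

type Term v = Fix (TermF v)

-- Same shape as the freeVars cases discussed in the text.
freeVars :: Ord v => Term v -> Set v
freeVars = foldFix go
  where
    go (Var v)      = S.singleton v
    go (App a b)    = a <> b
    go (Lam v a)    = S.delete v a
    go (Let b body) = bindRhs b <> (body `S.difference` S.singleton (bindLhs b))

-- \x -> f x
lamExample :: Term String
lamExample = Fix (Lam "x" (Fix (App (Fix (Var "f")) (Fix (Var "x")))))

-- let y = z in y w
letExample :: Term String
letExample =
  Fix (Let (Bind "y" (Fix (Var "z"))) (Fix (App (Fix (Var "y")) (Fix (Var "w")))))
```

For lamExample only f stays free, and for letExample the free variables are z (from the right-hand side) and w (from the body), while the bound y is removed.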
[1] series of posts: https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html
[2] data-fix package: https://hackage.haskell.org/package/data-fix-0.3.3/docs/Data-Fix.html
[3] type-inference project: https://github.com/anton-k/hindley-milner-type-check