This time, I would like to talk about recursion schemes: the idea, the benefits, and a real-world example.
The materials I follow include the well-known series of posts[1], the data-fix package[2] and a real-world type-inference project[3].
Briefly speaking, recursion schemes are another style of expressing recursive data types in functional programming languages (such as Haskell). They take advantage of fixpoints and provide a more abstract traversal mechanism compared with the "naive" definition of a recursive data type.
I will not go into great detail here, since the posts[1] have already demonstrated the idea clearly; a curious mind should consider referring to them. Let's instead start with a very brief summary of the idea, then jump directly into its usage in a real-world language project.
The definition of the example AST (from the post[1]) is:
data Lit
  = StrLit String
  | IntLit Int
  | Ident String
  deriving (Show, Eq)

data Expr
  = Index Expr Expr
  | Call Expr [Expr]
  | Unary String Expr
  | Binary Expr String Expr
  | Paren Expr
  | Literal Lit
  deriving (Show, Eq)
This is the "naive" definition and would be familiar to any haskellers. The trick is that the constructors of Expr could also take the Expr datatype as its parameters, so that a recursive datatype is declared.
The recursion-schemes style, however, defines the above Expr in this way:
-- TemplateHaskell is only used for deriving a Show1 instance below
{-# LANGUAGE TemplateHaskell #-}
-- optional, only to demonstrate that ExprF can be a Functor
{-# LANGUAGE DeriveFunctor #-}

-- for better demonstration, we will use `data-fix`
import Data.Fix
-- better support of the `show` function (provides deriveShow1)
import Text.Show.Deriving

data Lit
  = StrLit String
  | IntLit Int
  | Ident String
  deriving (Show, Eq)

data ExprF a
  = Index a a
  | Call a [a]
  | Unary String a
  | Binary a String a
  | Paren a
  | Literal Lit
  deriving (Show, Eq, Functor)

-- better support of the `show` function
$(deriveShow1 ''ExprF)

type Expr = Fix ExprF
-- the Fix datatype is just:
-- newtype Fix f = Fix { unFix :: f (Fix f) }

d1 = Fix (Literal (IntLit 10)) :: Expr
d2 = Fix $ Paren $ Fix (Literal (IntLit 10)) :: Expr
As you can see, Expr is defined as the fixpoint of ExprF, while ExprF just represents our AST in a non-recursive way: at first glance it is hard to see how ExprF a alone could build any tree-like structure.
However, the fixpoint trick makes this possible.
Take a look at the definition of Fix (which we imported from Data.Fix in the example above):
newtype Fix f = Fix { unFix :: f (Fix f) }
The trick is to move the recursion to the fixpoint level: each Fix wraps one layer of f, whose recursive positions in turn contain further Fix f values.
To better understand this, consider a marginal (leaf) case:
Fix (Literal (IntLit 10)) :: Expr
Will this pass the type-check?
Yes! Because Literal (IntLit 10) type-checks as ExprF a, where a can stand for any type. More specifically:
l1 = Literal (IntLit 10) :: ExprF a
l2 = Literal (IntLit 10) :: ExprF ()
l3 = Literal (IntLit 10) :: ExprF (ExprF a)
l4 = Literal (IntLit 10) :: ExprF (Fix ExprF)
can all be type-checked.
With this, Fix l4 :: Fix ExprF type-checks successfully.
Then, immediately, you can see that
d2 = Fix $ Paren $ Fix (Literal (IntLit 10)) :: Expr
also type-checks, since applying Fix to a value of type f (Fix f) yields a Fix f.
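The same reasoning extends to arbitrarily deep nestings. As a small illustration of my own (not from the post[1]), here is the index expression 10[10] built by wrapping each layer in Fix:
-- my own example: the index expression 10[10], two layers deep
d3 :: Expr
d3 = Fix (Index (Fix (Literal (IntLit 10))) (Fix (Literal (IntLit 10))))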
So, one reason why recursion schemes work is that a marginal case of ExprF a type-checks for any choice of a; or, put another way, a marginal case such as Literal (IntLit 10) leaves the type variable completely unconstrained.
Knowing this is important, since the traversal mechanisms provided by recursion-schemes only ask you to write functions over a single layer of the structure; the recursive plumbing for all the other layers is handled automatically, without any extra code!
Consider a folding case from the post[1]:
import Text.PrettyPrint (Doc)
import qualified Text.PrettyPrint as P

ten, add, call :: Expr
ten  = Fix (Literal (IntLit 10))
add  = Fix (Literal (Ident "add"))
call = Fix (Call add [ten, ten]) -- add(10, 10)

type Algebra f a = f a -> a

prettyPrint :: Algebra ExprF Doc
prettyPrint (Literal (IntLit i)) = P.int i
prettyPrint (Literal (StrLit s)) = P.text (show s)  -- added so the match is total
prettyPrint (Literal (Ident s))  = P.text s
prettyPrint (Call f as) =
  -- f(a, b, ...)
  f <> P.parens (mconcat (P.punctuate P.comma as))
prettyPrint (Index it idx) =
  -- a[b]
  it <> P.brackets idx
prettyPrint (Unary op it) =
  -- op x
  P.text op <> it
prettyPrint (Binary l op r) =
  -- lhs op rhs
  l <> P.text op <> r
prettyPrint (Paren e) =
  -- (x)
  P.parens e

v = foldFix prettyPrint call
Rendering v (for instance with P.render) results in add(10,10).
The magic foldFix provided by data-fix has this type signature:
foldFix :: Functor f => (f a -> a) -> Fix f -> a
This function is named cata in the post, and the implementation looks like this:
import Control.Arrow
-- foldFix is also named cata, meaning catamorphism
foldFix :: (Functor f) => Algebra f a -> Fix f -> a
foldFix fn = unFix >>> fmap (foldFix fn) >>> fn
What this function does is:
first, use unFix to unwrap the content of type f (Fix f);
then, apply fmap (foldFix fn) to it, which is the recursive call; after this step the f (Fix f) becomes an f a;
finally, apply fn to the result, transforming the f a into an a.
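To see these three steps concretely, here is a hand evaluation (my own trace, not from the post) of foldFix prettyPrint applied to the leaf ten defined above:
-- foldFix prettyPrint (Fix (Literal (IntLit 10)))
--   = prettyPrint (fmap (foldFix prettyPrint) (unFix (Fix (Literal (IntLit 10)))))
--   = prettyPrint (fmap (foldFix prettyPrint) (Literal (IntLit 10)))
--   = prettyPrint (Literal (IntLit 10))  -- a leaf has no recursive positions, so fmap changes nothing
--   = P.int 10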
Looks familiar? Yes, it is just like the foldr function!
As its name indicates, foldFix is the Fix version of a fold.
In fact, we can write any recursive function we like, just as with the "naive" AST definition: foldr, foldl, map, mapM, whatever!
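To make the analogy with foldr concrete, here is a small sketch of my own (the names ListF, List and sumList are illustrative, not part of any library; it assumes DeriveFunctor is enabled as above) expressing lists and their sum in the same fixpoint style:
-- a list functor: the recursive tail position is abstracted as `a`
data ListF e a = NilF | ConsF e a
  deriving Functor

type List e = Fix (ListF e)

-- summing a list is foldFix with an algebra over a single layer
sumList :: List Int -> Int
sumList = foldFix alg
  where
    alg NilF        = 0
    alg (ConsF x s) = x + s

-- sumList (Fix (ConsF 1 (Fix (ConsF 2 (Fix NilF))))) == 3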
Now, let's go back to the example:
Consider the implementation of prettyPrint (Index it idx):
prettyPrint (Index it idx) = it <> P.brackets idx
We know that prettyPrint has the type signature prettyPrint :: ExprF Doc -> Doc, so it is easy to see why there is a <> between the two objects it and P.brackets idx, since (<>) :: Doc -> Doc -> Doc. But wait, do it and idx really have the type Doc?
The answer is yes, for the same reason that foldr works.
More specifically, consider the two marginal cases:
prettyPrint (Literal (IntLit i)) = P.int i
prettyPrint (Literal (Ident s)) = P.text s
By looking at these, you can easily see that the Expr AST will be reduced into Doc starting from the marginal cases (the leaves) and propagating up through the traversal, so by the time the fold reaches an Index a a node, the a values are already Docs.
The key point for understanding this smoothly, again, is to remember that Literal (IntLit i) type-checks as ExprF a for any a, so the leaves need no special treatment, just as in the "naive" AST definition.
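Tracing the call example (my own trace again) makes the bottom-up propagation visible; by the time the algebra sees the Call layer, its children are already Docs:
-- foldFix prettyPrint call
--   = prettyPrint (Call (foldFix prettyPrint add) [foldFix prettyPrint ten, foldFix prettyPrint ten])
--   = prettyPrint (Call (P.text "add") [P.int 10, P.int 10])   -- this layer has type ExprF Doc
--   = P.text "add" <> P.parens (P.int 10 <> P.comma <> P.int 10)
-- which renders as add(10,10)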
Now that we know a little bit about recursion schemes, their definition and the basic idea, let's move on to a real-world example: hindley-milner-type-check[3].
As its name suggests, hindley-milner-type-check is a library that provides HM-based type inference for an arbitrary language based on the lambda calculus. It ships a core language defined in the recursion-schemes style, and a custom language should be mapped onto it in order to use its type-inference facility. Today, we mainly focus on its language definition and some folds over it.
The language is defined as follows:
import Data.Fix
import Data.Eq.Deriving
import Data.Ord.Deriving
import Text.Show.Deriving
-- | Term functor. The arguments are:
--   prim for language primitives
--   loc  for source code locations
--   v    for variables
--   r    for recursion
data TermF prim loc v r
  = Var loc v                        -- ^ Variables.
  | Prim loc prim                    -- ^ Primitives.
  | App loc r r                      -- ^ Applications.
  | Lam loc v r                      -- ^ Abstractions.
  | Let loc (Bind loc v r) r         -- ^ Let bindings.
  | LetRec loc [Bind loc v r] r      -- ^ Recursive let bindings.
  | AssertType loc r (Type loc v)    -- ^ Assert type.
  | Case loc r [CaseAlt loc v r]     -- ^ Case alternatives.
  | Constr loc v                     -- ^ Constructor with tag.
  | Bottom loc                       -- ^ Value of any type that means a failed program.
  deriving (Show, Eq, Functor, Foldable, Traversable, Data)
-- | Case alternatives
data CaseAlt loc v a = CaseAlt
  { caseAlt'loc  :: loc
    -- ^ source code location
  , caseAlt'tag  :: v
    -- ^ tag of the constructor
  , caseAlt'args :: [(loc, v)]
    -- ^ arguments of the pattern matching
  , caseAlt'rhs  :: a
    -- ^ right-hand side of the case-alternative
  }
  deriving (Show, Eq, Functor, Foldable, Traversable, Data)
-- | Local variable definition.
--
-- > let lhs = rhs in ...
data Bind loc var a = Bind
  { bind'loc :: loc  -- ^ Source code location
  , bind'lhs :: var  -- ^ Variable name
  , bind'rhs :: a    -- ^ Definition (right-hand side)
  } deriving (Show, Eq, Functor, Foldable, Traversable, Data)
$(deriveShow1 ''TermF)
$(deriveEq1 ''TermF)
$(deriveOrd1 ''TermF)
$(deriveShow1 ''Bind)
$(deriveEq1 ''Bind)
$(deriveOrd1 ''Bind)
$(deriveShow1 ''CaseAlt)
$(deriveEq1 ''CaseAlt)
$(deriveOrd1 ''CaseAlt)
Take a look at the TermF definition. It has four type variables: prim for language primitives (not important for our discussion), loc for source code locations, v for variables and r for recursion.
loc is likewise not important for our study, so we just ignore it. v is the variable type we register into the AST, and it corresponds to the marginal cases we will discuss later. r is the recursion variable; during folding, the values in r positions are replaced by the results already computed for the subterms. Note that CaseAlt and Bind also receive r as the recursion variable inside the definition of TermF; this is also a demonstration of how to define mutually recursive data types with recursion schemes.
Here is also the definition of Term, built from TermF via the fixpoint:
-- | The type of terms.
newtype Term prim loc v = Term { unTerm :: Fix (TermF prim loc v) }
  deriving (Show, Eq, Data)
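To get a feeling for how such terms are written by hand, here is a tiny example of my own; the concrete choices prim ~ Int, loc ~ () and v ~ String are assumptions for illustration only:
-- let x = 1 in x
example :: Term Int () String
example = Term $ Fix $
  Let () (Bind () "x" (Fix (Prim () 1)))
         (Fix (Var () "x"))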
Now let's take a look at a simple folding example using recursion schemes: getting the free variables of an AST.
-- | Get free variables of the term.
-- (assumes: import Data.Set (Set); import qualified Data.Set as S;
--  plus the LambdaCase and RecordWildCards extensions)
freeVars :: Ord v => Term prim loc v -> Set v
freeVars = foldFix go . unTerm
  where
    go = \case
      Var _ v   -> S.singleton v
      Prim _ _  -> mempty
      App _ a b -> mappend a b
      Lam _ v a -> S.delete v a
      Let _ bind body ->
        let lhs = S.singleton $ bind'lhs bind
        in  mappend (bind'rhs bind) (body `S.difference` lhs)
      LetRec _ binds body ->
        let lhs = S.fromList $ fmap bind'lhs binds
        in  (mappend (freeBinds binds) body) `S.difference` lhs
      AssertType _ a _ -> a
      Case _ e alts    -> mappend e (foldMap freeVarAlts alts)
      Constr _ _       -> mempty
      Bottom _         -> mempty

    freeBinds = foldMap bind'rhs

    freeVarAlts CaseAlt{..} =
      caseAlt'rhs `S.difference` (S.fromList $ fmap snd caseAlt'args)
Again, let's first focus on the two marginal cases: Var loc v and Prim loc prim.
The Var loc v case is simple: since v itself IS a free variable, we have
Var _ v -> S.singleton v
The Prim loc prim case is even simpler: it contains no variables at all.
Prim _ _ -> mempty
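As a sanity check of my own (using the concrete type parameters assumed in the earlier example term), the two marginal cases alone already behave as expected:
-- freeVars (Term (Fix (Var () "x"))  :: Term Int () String) == S.singleton "x"
-- freeVars (Term (Fix (Prim () 42))  :: Term Int () String) == S.empty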
Then, let's take a look at three non-marginal cases as examples (the other cases follow the same pattern): App loc r r, Lam loc v r and Let loc (Bind loc v r) r.
The App loc r r case is simple: applying r1 to r2 just merges the free variables of r1 and r2. And remember that under foldFix, each r in App loc r r has already been reduced to the result type of the algebra (i.e. a Set v), so mappend is enough:
App _ a b -> mappend a b
The Lam loc v r case looks a little more complex at first glance, since it contains both a v and an r. But remember that the r will already have been reduced to a Set v under foldFix, so we only need a set operation on it:
Lam _ v a -> S.delete v a
We remove v from the set a since v is bound by the lambda.
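A small check of my own for the Lam and App cases, with the same assumed type parameters: in \x -> y x, only y is free.
-- \x -> y x
lamTerm :: Term Int () String
lamTerm = Term $ Fix $
  Lam () "x" (Fix (App () (Fix (Var () "y")) (Fix (Var () "x"))))
-- freeVars lamTerm == S.fromList ["y"]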
The Let loc (Bind loc v r) r case is the most complex one, since it also refers to another datatype, Bind. Recall the definition of Bind:
data Bind loc var a = Bind
  { bind'loc :: loc  -- ^ Source code location
  , bind'lhs :: var  -- ^ Variable name
  , bind'rhs :: a    -- ^ Definition (right-hand side)
  } deriving (Show, Eq, Functor, Foldable, Traversable, Data)
Looks complex? We just need to apply the same strategy we used for TermF to Bind.
Firstly, the v in Bind loc v r is the variable that corresponds to the marginal cases.
Secondly, r is the recursion variable, already reduced to a Set v during folding.
So we use bind'lhs bind to fetch the bound variable and bind'rhs bind to fetch the set of free variables of its definition, and that's it! The resulting case looks like this:
Let _ bind body ->
  let lhs = S.singleton $ bind'lhs bind
  in  mappend (bind'rhs bind) (body `S.difference` lhs)
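And a final check of my own for the Let case: in let x = y in add x, the bound x disappears while y (from the right-hand side) and add (from the body) remain free.
-- let x = y in add x
letTerm :: Term Int () String
letTerm = Term $ Fix $
  Let () (Bind () "x" (Fix (Var () "y")))
         (Fix (App () (Fix (Var () "add")) (Fix (Var () "x"))))
-- freeVars letTerm == S.fromList ["add", "y"]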
[1] series of posts: https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html
[2] data-fix package: https://hackage.haskell.org/package/data-fix-0.3.3/docs/Data-Fix.html
[3] type-inference project: https://github.com/anton-k/hindley-milner-type-check