This time, I would like to talk about recursion schemes: the idea, the benefits and a real-world example.

The materials I follow include the famous series of posts[1], the data-fix package[2] and a real-world type-inference project[3].

the definition

Briefly speaking, recursion schemes are another style of expressing recursive data types in functional programming languages (such as Haskell). The style takes advantage of fixpoints and separates the traversal of a structure from the work done at each node, which gives one more level of abstraction than the "naive" definition of a recursive data type.

I would not like to go into great detail, since the famous posts[1] have already explained it clearly, and a curious mind should consider referring to them. Let's instead start with a very brief summary of the idea, then jump directly into its usage in a real-world language project.

The definition of the example AST (from the post[1]) is:

data Lit
  = StrLit String
  | IntLit Int
  | Ident String
  deriving (Show, Eq)

data Expr
  = Index Expr Expr
  | Call Expr [Expr]
  | Unary String Expr
  | Binary Expr String Expr
  | Paren Expr
  | Literal Lit
  deriving (Show, Eq)

This is the "naive" definition and will be familiar to any Haskeller. The constructors of Expr take Expr itself as a parameter, which is what makes the datatype recursive.

The recursion-schemes style, however, defines the above Expr in this way:

-- use TemplateHaskell only for better support of `show` function
{-# LANGUAGE TemplateHaskell #-}
-- optional, only to demonstrate ExprF could be a Functor
{-# LANGUAGE DeriveFunctor #-}
-- for better demonstration, we will use `data-fix`
import Data.Fix 
-- better support of `show` function
import Text.Show.Deriving

data Lit
  = StrLit String
  | IntLit Int
  | Ident String
  deriving (Show, Eq)

data ExprF a
  = Index a a
  | Call a [a]
  | Unary String a
  | Binary a String a
  | Paren a
  | Literal Lit
  deriving (Show, Eq, Functor)

-- better support of `show` function
$(deriveShow1 ''ExprF)

type Expr = Fix ExprF

-- -- the Fix datatype is just 
-- newtype Fix f = Fix { unFix :: f (Fix f) }

d1 = Fix (Literal (IntLit 10)) :: Expr
d2 = Fix $ Paren $ Fix (Literal (IntLit 10)) :: Expr

As you can see, Expr is defined as the fixpoint of ExprF, while ExprF itself describes our AST in a non-recursive way. At first glance it is hard to see how ExprF a could build a tree-like structure at all.

However, using the fixpoint trick, this becomes possible.

Take a look at the definition of Fix (which we imported from Data.Fix in the above example):

newtype Fix f = Fix { unFix :: f (Fix f) }

The trick is to move the recursion to the level of Fix: f describes a single layer of the structure, and Fix f fills every child position of that layer with another Fix f.

To better understand this, consider a marginal case (a leaf of the tree):

Fix (Literal (IntLit 10)) :: Expr

Will this pass the type-check?

Yes! Literal (IntLit 10) does not use the type variable a at all, so it type-checks as ExprF a for any type a. More specifically:

l1 = Literal (IntLit 10) :: ExprF a
l2 = Literal (IntLit 10) :: ExprF ()
l3 = Literal (IntLit 10) :: ExprF (ExprF a)
l4 = Literal (IntLit 10) :: ExprF (Fix ExprF)

all type-check.

With this, Fix l4 :: Fix ExprF type-checks as well.

Then, immediately, you can see

d2 = Fix $ Paren $ Fix (Literal (IntLit 10)) :: Expr

also passes the type-check, since applying Fix to a value of type f (Fix f) yields a value of type Fix f.

So, one reason why recursion schemes work is that a marginal case of ExprF a type-checks for any type variable a. Put another way, a marginal case such as Literal (IntLit 10) leaves the type variable free, so it fits wherever an ExprF of some type is expected.
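To make the type-checking argument concrete, here is a minimal sketch that needs only base. It defines Fix locally (same shape as the one in data-fix) and the names leaf, leafUnit, leafExpr and parenDepth are my own, introduced purely for illustration.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Local copy of Fix so the sketch needs only base (an assumption for
-- self-containment; the article itself imports it from Data.Fix).
newtype Fix f = Fix { unFix :: f (Fix f) }

data Lit = StrLit String | IntLit Int | Ident String
  deriving (Show, Eq)

data ExprF a
  = Index a a
  | Call a [a]
  | Unary String a
  | Binary a String a
  | Paren a
  | Literal Lit
  deriving (Show, Eq, Functor)

type Expr = Fix ExprF

-- A leaf never mentions `a`, so it is polymorphic in it.
leaf :: ExprF a
leaf = Literal (IntLit 10)

-- The same leaf used at two different instantiations of `a`.
leafUnit :: ExprF ()
leafUnit = leaf

leafExpr :: ExprF Expr
leafExpr = leaf

-- Wrapping with Fix ties the knot: ExprF Expr becomes Expr.
d1, d2 :: Expr
d1 = Fix leafExpr
d2 = Fix (Paren d1)

-- Hypothetical helper: count leading Paren layers by peeling Fix by hand.
parenDepth :: Expr -> Int
parenDepth (Fix (Paren e)) = 1 + parenDepth e
parenDepth _               = 0
```

Here parenDepth still recurses by hand, which is exactly the boilerplate that a fold over Fix removes.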

Knowing this is important, since the traversals provided by recursion schemes rely on it: the function you supply only handles one layer at a time, the marginal cases produce a result directly, and the recursion into child nodes is handled automatically without any extra code!

Consider a folding case from the post[1]:

import Text.PrettyPrint (Doc)
import qualified Text.PrettyPrint as P

ten, add, call :: Expr
ten  = Fix (Literal (IntLit 10))
add  = Fix (Literal (Ident "add" ))
call = Fix (Call add [ten, ten]) --add(10, 10)

type Algebra f a = f a -> a
prettyPrint :: Algebra ExprF Doc
prettyPrint (Literal (IntLit i)) = P.int i
prettyPrint (Literal (Ident s)) = P.text s
prettyPrint (Call f as)     = 
  ---f(a,b...)
  f <> P.parens (mconcat (P.punctuate P.comma as))
prettyPrint (Index it idx)  = 
  ---a[b]
  it <> P.brackets idx                
prettyPrint (Unary op it)   = 
  ---op x
  (P.text op) <> it                   
prettyPrint (Binary l op r) = 
  ---lhs op rhs
  l <> (P.text op) <> r               
prettyPrint (Paren exp)     = 
  ---(op)
  P.parens exp                        

v = foldFix prettyPrint call

This will result in add(10,10).

The magic foldFix provided by data-fix has this type signature:

foldFix :: Functor f => (f a -> a) -> Fix f -> a

This function is named cata in the post, and the implementation looks like this:

import Control.Arrow
-- foldFix is also named cata, meaning catamorphism
foldFix :: (Functor f) => Algebra f a -> Fix f -> a
foldFix fn = unFix >>> fmap (foldFix fn) >>> fn

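What this function does is: unwrap one layer with unFix, recursively fold every child of that layer with fmap (foldFix fn), and finally apply fn to the layer, whose children have by now been replaced with results of type a.

Looks familiar? Yes, the foldr function!

As its name indicates, foldFix is the Fix version of a fold.

In fact, we can write any recursive function we could write over a "naive" AST definition: foldr, foldl, map, mapM, whatever!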
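The similarity to foldr can be checked directly. The sketch below is only an illustration under my own assumptions: ListF, fromList, sumAlg and sumFix are hypothetical names, and Fix and foldFix are redefined locally so that only base is needed.

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype Fix f = Fix { unFix :: f (Fix f) }

-- Same behaviour as the definition in the text, written with (.)
-- instead of (>>>) to avoid the Control.Arrow import.
foldFix :: Functor f => (f a -> a) -> Fix f -> a
foldFix fn = fn . fmap (foldFix fn) . unFix

-- Hypothetical base functor for lists: `r` marks the recursive position.
data ListF e r = Nil | Cons e r
  deriving Functor

-- Convert an ordinary list into its Fix representation.
fromList :: [e] -> Fix (ListF e)
fromList = foldr (\x acc -> Fix (Cons x acc)) (Fix Nil)

-- An algebra handles one layer; `acc` is the already-folded tail,
-- playing the role of foldr's accumulator argument.
sumAlg :: ListF Int Int -> Int
sumAlg Nil          = 0
sumAlg (Cons x acc) = x + acc

-- Sum a list by folding its Fix representation.
sumFix :: [Int] -> Int
sumFix = foldFix sumAlg . fromList
```

With these definitions, sumFix xs and foldr (+) 0 xs agree: Nil corresponds to the initial value and Cons to the combining function.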

Now, let's go back to the example:

Consider the implementation of prettyPrint (Index it idx):

prettyPrint (Index it idx)  = it <> P.brackets idx

We know that prettyPrint has the type signature prettyPrint :: ExprF Doc -> Doc. Since (<>) :: Doc -> Doc -> Doc, it is easy to see why <> sits between it and P.brackets idx. But wait, do it and idx really have type Doc?

The answer is yes, just as the accumulator argument of the function passed to foldr already holds the folded result of the rest of the list.

More specifically, consider the two marginal cases:

prettyPrint (Literal (IntLit i)) = P.int i
prettyPrint (Literal (Ident s)) = P.text s

Looking at these, you can see that the Expr AST is reduced to Doc starting from the marginal cases (the leaves), with results propagating upwards through the traversal. By the time prettyPrint sees an Index a a node, each a is already a Doc.

The key to understanding this, again, is to remember that Literal (IntLit i) type-checks as ExprF a for any a, in particular as ExprF Doc, so a leaf can be handled without any already-folded children.
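To watch the leaves-first reduction without installing anything, here is a dependency-free variant of the pretty printer that I sketched with String in place of Doc. The names render and indexed are hypothetical, Fix and foldFix are local copies, and the StrLit case is my own addition to keep the function total.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.List (intercalate)

newtype Fix f = Fix { unFix :: f (Fix f) }

foldFix :: Functor f => (f a -> a) -> Fix f -> a
foldFix fn = fn . fmap (foldFix fn) . unFix

data Lit = StrLit String | IntLit Int | Ident String

data ExprF a
  = Index a a
  | Call a [a]
  | Unary String a
  | Binary a String a
  | Paren a
  | Literal Lit
  deriving Functor

type Expr = Fix ExprF

-- Every `a` position is already a String when render sees the node.
render :: ExprF String -> String
render (Literal (IntLit i)) = show i
render (Literal (Ident s))  = s
render (Literal (StrLit s)) = show s
render (Call f args)        = f ++ "(" ++ intercalate "," args ++ ")"
render (Index it idx)       = it ++ "[" ++ idx ++ "]"
render (Unary op it)        = op ++ it
render (Binary l op r)      = l ++ op ++ r
render (Paren e)            = "(" ++ e ++ ")"

ten, add, call, indexed :: Expr
ten     = Fix (Literal (IntLit 10))
add     = Fix (Literal (Ident "add"))
call    = Fix (Call add [ten, ten])         -- add(10,10)
indexed = Fix (Index add (Fix (Paren ten))) -- add[(10)]
```

Folding call with foldFix render first turns the three leaves into "add", "10" and "10", and only then is the Call case applied to those strings.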

a real world example

Now that we know a little about recursion schemes, their definition and the basic idea, we move on to a real-world example: hindley-milner-type-check[3].

As its name suggests, hindley-milner-type-check is a library that provides HM-based type inference for an arbitrary language based on lambda calculus. It ships a small language defined in the recursion-schemes style, and a custom language has to be mapped onto it in order to use its type-inference facility. Today, we mainly focus on its language definition and some folds over it.

The language is defined as follows:

import Data.Fix
import Data.Eq.Deriving
import Data.Ord.Deriving
import Text.Show.Deriving

-- | Term functor. The arguments are
-- loc for source code locations
-- v for variables
-- r for recursion

data TermF prim loc v r
    = Var loc v                       -- ^ Variables.
    | Prim loc prim                   -- ^ Primitives.
    | App loc r r                     -- ^ Applications.
    | Lam loc v r                     -- ^ Abstractions.
    | Let loc (Bind loc v r) r        -- ^ Let bindings.
    | LetRec loc [Bind loc v r] r     -- ^ Recursive  let bindings
    | AssertType loc r (Type loc v)   -- ^ Assert type.
    | Case loc r [CaseAlt loc v r]    -- ^ case alternatives
    | Constr loc v                    -- ^ constructor with tag
    | Bottom loc                      -- ^ value of any type that means failed program.
    deriving (Show, Eq, Functor, Foldable, Traversable, Data)

-- | Case alternatives
data CaseAlt loc v a = CaseAlt
  { caseAlt'loc   :: loc
  -- ^ source code location
  , caseAlt'tag   :: v
  -- ^ tag of the constructor
  , caseAlt'args  :: [(loc, v)]
  -- ^ arguments of the pattern matching
  , caseAlt'rhs   :: a
  -- ^ right-hand side of the case-alternative
  }
  deriving (Show, Eq, Functor, Foldable, Traversable, Data)

-- | Local variable definition.
--
-- > let lhs = rhs in ...
data Bind loc var a = Bind
  { bind'loc :: loc             -- ^ Source code location
  , bind'lhs :: var             -- ^ Variable name
  , bind'rhs :: a               -- ^ Definition (right-hand side)
  } deriving (Show, Eq, Functor, Foldable, Traversable, Data)

$(deriveShow1 ''TermF)
$(deriveEq1   ''TermF)
$(deriveOrd1  ''TermF)
$(deriveShow1 ''Bind)
$(deriveEq1   ''Bind)
$(deriveOrd1  ''Bind)
$(deriveShow1 ''CaseAlt)
$(deriveEq1   ''CaseAlt)
$(deriveOrd1  ''CaseAlt)

Take a look at the TermF definition. It has four type parameters: prim for primitives, loc for source code locations, v for variables and r for recursion.

loc is not important for our study, so we just ignore it. v is the type of variable names stored in the AST, and it shows up in the marginal cases we will discuss later. r is the recursion variable: during folding, every r is replaced with the already-computed result for that subterm. Note that CaseAlt and Bind also take r as their last parameter inside TermF. This demonstrates how data types that would be mutually recursive in the "naive" style can be defined using recursion schemes.

Term is then defined as the fixpoint of TermF:
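A small sketch may help to see how the recursion variable reaches into an auxiliary type. The cut-down Bind and MiniF below are hypothetical simplifications of the library's types (no loc, no prim), written only for illustration.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- Hypothetical cut-down Bind: the last parameter is the recursion variable.
data Bind v a = Bind { bindLhs :: v, bindRhs :: a }
  deriving (Show, Eq, Functor)

-- Hypothetical cut-down term functor; Let embeds Bind at the same `r`.
data MiniF v r
  = Var v
  | App r r
  | Let (Bind v r) r
  deriving (Show, Eq, Functor)

-- fmap rewrites every recursive position, including the one inside Bind,
-- while the variable name "x" is left untouched.
example :: MiniF String Int
example = fmap length (Let (Bind "x" "ab") "cde")
```

Because the derived Functor instance of MiniF goes through the Functor instance of Bind, a fold never has to treat the right-hand side of a binding specially.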

-- | The type of terms.
newtype Term prim loc v = Term { unTerm :: Fix (TermF prim loc v) }
  deriving (Show, Eq, Data)

simple folding example

Now let's take a look at a simple folding example using recursion-schemes:

-- | Get free variables of the term.
freeVars :: Ord v => Term prim loc v -> Set v
freeVars = foldFix go . unTerm
  where
    go = \case
      Var    _ v          -> S.singleton v
      Prim   _ _          -> mempty
      App    _ a b        -> mappend a b
      Lam    _ v a        -> S.delete v a
      Let    _ bind body  -> let lhs = S.singleton $ bind'lhs bind
                             in  mappend (bind'rhs bind)
                                         (body `S.difference` lhs)
      LetRec _ binds body -> let 
                              lhs = S.fromList $ fmap bind'lhs binds
                             in  
                             (mappend (freeBinds binds) body) 
                             `S.difference` lhs
      AssertType _ a _    -> a
      Case _ e alts       -> mappend e (foldMap freeVarAlts alts)
      Constr _ _          -> mempty
      Bottom _            -> mempty

    freeBinds = foldMap bind'rhs

    freeVarAlts CaseAlt{..} = 
      caseAlt'rhs `S.difference` 
      (S.fromList $ fmap snd caseAlt'args)

Again, let's first focus on two of the marginal cases: Var loc v and Prim loc prim.

The Var loc v case is simple: since v IS a free variable, we have

Var _ v -> S.singleton v

The Prim loc prim case is even simpler, as it has no variables inside:

Prim _ _ -> mempty

Then, let's take a look at three non-marginal cases as examples (the other cases follow the same rule): App loc r r, Lam loc v r and Let loc (Bind loc v r) r.

The App loc r r case is simple, since the free variables of an application are just the free variables of both subterms merged together. Remember that under foldFix, each r in App loc r r has already been reduced to the fold result (i.e. Set v), so mappend is enough:

App _ a b -> mappend a b

The Lam loc v r case looks a little complex at first glance, since it has both v and r inside. But recall that r has already been reduced to Set v under foldFix, so we just need a set operation on it:

Lam _ v a -> S.delete v a

We remove v from the set a since v is bound by the lambda.

The Let loc (Bind loc v r) r case is the most complex one, since it also refers to another datatype, Bind. Recall the definition of Bind:

data Bind loc var a = Bind
  { bind'loc :: loc             -- ^ Source code location
  , bind'lhs :: var             -- ^ Variable name
  , bind'rhs :: a               -- ^ Definition (right-hand side)
  } deriving (Show, Eq, Functor, Foldable, Traversable, Data)

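Looks complex? We just need to apply the strategy we used to analyze TermF to Bind.

So we use bind'lhs bind to fetch the bound variable and bind'rhs bind to fetch the already-folded Set of free variables of the definition, that's it! The bound variable is removed from the body only, because in a non-recursive let it is not in scope in its own right-hand side. The resulting function looks like this: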

Let _ bind body  -> 
  let 
    lhs = S.singleton $ bind'lhs bind
  in  
    mappend (bind'rhs bind) (body `S.difference` lhs)
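To try the cases discussed above in isolation, here is a reduced sketch of freeVars over a hypothetical four-constructor functor MiniF (no loc, no prim, no LetRec or Case). It is not the library's code: Fix and foldFix are local copies, and t1 and t2 are made-up test terms. It assumes the containers package, which ships with GHC.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import           Data.Set (Set)
import qualified Data.Set as S

newtype Fix f = Fix { unFix :: f (Fix f) }

foldFix :: Functor f => (f a -> a) -> Fix f -> a
foldFix fn = fn . fmap (foldFix fn) . unFix

-- Hypothetical cut-down versions of Bind and TermF.
data Bind v a = Bind { bindLhs :: v, bindRhs :: a }
  deriving Functor

data MiniF v r
  = Var v
  | App r r
  | Lam v r
  | Let (Bind v r) r
  deriving Functor

-- Each `r` position already holds the free variables of that subterm.
freeVars :: Ord v => Fix (MiniF v) -> Set v
freeVars = foldFix go
  where
    go (Var v)         = S.singleton v
    go (App a b)       = a <> b
    go (Lam v a)       = S.delete v a
    -- Non-recursive let: the bound name is removed from the body only.
    go (Let bind body) = bindRhs bind <> S.delete (bindLhs bind) body

t1, t2 :: Fix (MiniF String)
-- let x = y in \z -> x z
t1 = Fix (Let (Bind "x" (Fix (Var "y")))
              (Fix (Lam "z" (Fix (App (Fix (Var "x")) (Fix (Var "z")))))))
-- let x = x in x   (the x on the right-hand side stays free)
t2 = Fix (Let (Bind "x" (Fix (Var "x"))) (Fix (Var "x")))
```

In t1 only y is free, while in t2 the x on the right-hand side remains free, which is the behaviour the Let case encodes by deleting the bound name from the body but not from bindRhs.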

reference

[1] series of posts: https://blog.sumtypeofway.com/posts/introduction-to-recursion-schemes.html

[2] data-fix package: https://hackage.haskell.org/package/data-fix-0.3.3/docs/Data-Fix.html

[3] type-inference project: https://github.com/anton-k/hindley-milner-type-check

CC BY-SA 4.0 Septimia Zenobia. Last modified: November 16, 2024. Website built with Franklin.jl and the Julia programming language.