{-# LANGUAGE DeriveFunctor #-}

{-|

Module      : MyFreeMonad
Description : Implement a Free Monad using a DSL as an example.
Copyright   : © Frank Jung, 2023
License     : GPL-3

The code was originally provided by ChatGPT. It had to be modified to
compile. I added tests.

-}

module MyFreeMonad ( ArithM
                   , ArithF (..)
                   , addM
                   , subM
                   , mulM
                   , divM
                   , evalArith
                   , example
                   , example'
                   ) where

import           Control.Monad.Free (Free (..), liftF)

-- | Arithmetic functor: a single instruction of the DSL.
--
-- Every constructor pairs an 'Int' operand with a continuation of type @x@
-- (the remainder of the program). The derived 'Functor' instance maps over
-- that continuation and leaves the operand untouched.
data ArithF x
  = Add Int x -- ^ Addition with the given operand.
  | Sub Int x -- ^ Subtraction with the given operand.
  | Mul Int x -- ^ Multiplication with the given operand.
  | Div Int x -- ^ Integer division with the given operand.
  deriving (Show, Functor)

-- | Arithmetic free monad: the 'Free' monad generated by the 'ArithF'
-- instruction functor. Programs in the DSL have type @ArithM a@.
type ArithM = Free ArithF

-- | Interpret an arithmetic program and return its integer value.
--
-- The value of a 'Pure' leaf is the seed. Each instruction first evaluates
-- the rest of the program (its continuation) and then applies its own
-- operation to that result, so the instruction written /last/ in a program
-- is applied /first/ to the seed.
--
-- 'Div' uses 'div' (integer division truncating toward negative infinity);
-- a 'Div' with a zero operand raises the usual divide-by-zero exception.
evalArith :: Free ArithF Int -> Int
evalArith (Free (Add x n)) = evalArith n + x
evalArith (Free (Sub x n)) = evalArith n - x
evalArith (Free (Mul x n)) = evalArith n * x
evalArith (Free (Div x n)) = evalArith n `div` x
evalArith (Pure x)         = x

-- | Program consisting of a single 'Add' instruction with operand @x@.
--
-- Built with 'liftF', so the instruction's continuation simply yields @()@.
addM :: Int -> ArithM ()
addM x = liftF (Add x ())

-- | Program consisting of a single 'Sub' instruction with operand @x@.
--
-- Built with 'liftF', so the instruction's continuation simply yields @()@.
subM :: Int -> ArithM ()
subM x = liftF (Sub x ())

-- | Program consisting of a single 'Mul' instruction with operand @x@.
--
-- Built with 'liftF', so the instruction's continuation simply yields @()@.
mulM :: Int -> ArithM ()
mulM x = liftF (Mul x ())

-- | Program consisting of a single 'Div' instruction with operand @x@.
--
-- Built with 'liftF', so the instruction's continuation simply yields @()@.
-- Evaluating a program that contains @divM 0@ with 'evalArith' divides by
-- zero.
divM :: Int -> ArithM ()
divM x = liftF (Div x ())

-- | A simple example on how to use DSL:
--
-- @example 0 = ((((0+10)*2)-10)/2) == 5@
--
-- Get back the integer value with:
--
-- @evalArith (example 0)@
--
-- Note that 'evalArith' applies the instructions from the bottom up: the
-- seed @n@ is first increased by 10, then doubled, then reduced by 10 and
-- finally halved.
example :: Int -> ArithM Int
example n = do
  divM 2
  subM 10
  mulM 2
  addM 10
  pure n

-- | Another example on how to use this DSL:
--
-- @evalArith (example' n) == div (n * 2) 2@
--
-- As 'evalArith' applies the instructions from the bottom up, the seed @n@
-- is doubled first and halved afterwards.
example' :: Int -> ArithM Int
example' n = do
  divM 2
  mulM 2
  pure n