Hello Café
Issue:
I'm having trouble using the AD package, presumably because I'm using types incorrectly (sorry, I know that's a bit vague).
Any help or pointer on how to solve this would be greatly appreciated.
Thanks in advance
Background:
-- | Expression Tree
data ExprTree a = Const a                                       -- ^ number
                | ParamNode ParamName                           -- ^ parameter
                | VarNode VarName                               -- ^ variable
                | UnaryNode MonoOp (ExprTree a)                 -- ^ operator of arity 1
                | BinaryNode DualOp (ExprTree a) (ExprTree a)   -- ^ operator of arity 2
                | CondNode (Cond a) (ExprTree a) (ExprTree a)   -- ^ conditional node
                deriving (Eq, Show, Generic)
An evaluation function for it:
-- | evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a
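To make the types concrete, here is a cut-down, self-contained sketch of such a tree and evaluator. ParamName, VarName, MonoOp and DualOp are hypothetical stand-ins (their real definitions aren't shown), CondNode is left out, and therefore the Ord constraint is dropped. It is an illustration only, not the actual implementation.

```haskell
import Data.Map (Map)
import qualified Data.Map as Map

-- Hypothetical stand-ins for types not shown in the post.
type ParamName = String
type VarName = String

data MonoOp = Neg | Exp | Log | Sin deriving (Eq, Show)
data DualOp = Add | Sub | Mul | Div deriving (Eq, Show)

-- Cut-down tree: the post's CondNode (and its Cond type) is omitted.
data ExprTree a
  = Const a                                     -- number
  | ParamNode ParamName                         -- parameter
  | VarNode VarName                             -- variable
  | UnaryNode MonoOp (ExprTree a)               -- operator of arity 1
  | BinaryNode DualOp (ExprTree a) (ExprTree a) -- operator of arity 2
  deriving (Eq, Show)

-- (parameter values, variable values)
type Context a = (Map ParamName a, Map VarName a)

evaluate :: Floating a => ExprTree a -> Context a -> a
evaluate tree ctx@(params, vars) = case tree of
  Const c -> c
  ParamNode p -> look p params
  VarNode v -> look v vars
  UnaryNode op t -> unary op (evaluate t ctx)
  BinaryNode op l r -> binary op (evaluate l ctx) (evaluate r ctx)
  where
    -- Fail loudly on an unbound name rather than silently using a default.
    look name m = Map.findWithDefault (error ("unbound name: " ++ name)) name m
    unary Neg = negate
    unary Exp = exp
    unary Log = log
    unary Sin = sin
    binary Add = (+)
    binary Sub = (-)
    binary Mul = (*)
    binary Div = (/)
```

For example, with parameter k = 2 and variable x = 3, the tree `BinaryNode Mul (ParamNode "k") (UnaryNode Neg (VarNode "x"))` evaluates to -6.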
And a few instances for it:
instance Num a => Default (ExprTree a) where ...
instance Num a => Num (ExprTree a) where ...
instance Fractional a => Fractional (ExprTree a) where ...
instance Floating a => Floating (ExprTree a) where ...
instance (Arbitrary a) => Arbitrary (ExprTree a) where ...
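The instance bodies are elided above. Purely to illustrate what such instances buy, here is a hypothetical Num/Fractional pair for a cut-down tree (not the actual code): each arithmetic operation builds a node instead of computing a number, so trees can be written with ordinary arithmetic syntax.

```haskell
-- Hypothetical, cut-down operator and tree types (not the post's definitions).
data MonoOp = Neg | Abs | Signum deriving (Eq, Show)
data DualOp = Add | Sub | Mul | Div deriving (Eq, Show)

data ExprTree a
  = Const a
  | VarNode String
  | UnaryNode MonoOp (ExprTree a)
  | BinaryNode DualOp (ExprTree a) (ExprTree a)
  deriving (Eq, Show)

-- Each arithmetic operation builds a node instead of computing a number.
instance Num a => Num (ExprTree a) where
  (+) = BinaryNode Add
  (-) = BinaryNode Sub
  (*) = BinaryNode Mul
  negate = UnaryNode Neg
  abs = UnaryNode Abs
  signum = UnaryNode Signum
  fromInteger = Const . fromInteger   -- integer literals become constants

instance Fractional a => Fractional (ExprTree a) where
  (/) = BinaryNode Div
  fromRational = Const . fromRational -- decimal literals become constants
```

With these, `x * x + 1 / 2` (for `x = VarNode "x"`) is simply the tree `BinaryNode Add (BinaryNode Mul x x) (BinaryNode Div (Const 1) (Const 2))`.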
This lets me easily build (using differentiation rules) the derivative(s) of such a tree with respect to its variable(s):
diff :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
grad :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
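As an illustration of the rule-based approach, here is a hypothetical diff over a cut-down tree (no parameters, no conditionals) applying the sum, product, quotient and chain rules, plus a small evaluator used only to check one case by hand. It is a sketch, not the actual implementation; grad and hessian would then just apply it once per variable, or twice.

```haskell
import qualified Data.Map as Map

-- Hypothetical cut-down tree (no parameters, no conditionals).
data MonoOp = Neg | Exp | Log | Sin | Cos deriving (Eq, Show)
data DualOp = Add | Sub | Mul | Div deriving (Eq, Show)

data ExprTree a
  = Const a
  | VarNode String
  | UnaryNode MonoOp (ExprTree a)
  | BinaryNode DualOp (ExprTree a) (ExprTree a)
  deriving (Eq, Show)

-- Symbolic derivative with respect to the variable named x.
diff :: Num a => ExprTree a -> String -> ExprTree a
diff tree x = case tree of
  Const _ -> Const 0
  VarNode v -> Const (if v == x then 1 else 0)
  -- chain rule: f(u)' = f'(u) * u'
  UnaryNode op u -> BinaryNode Mul (outer op u) (diff u x)
  BinaryNode Add u v -> BinaryNode Add (diff u x) (diff v x)
  BinaryNode Sub u v -> BinaryNode Sub (diff u x) (diff v x)
  -- product rule: (uv)' = u'v + uv'
  BinaryNode Mul u v ->
    BinaryNode Add (BinaryNode Mul (diff u x) v) (BinaryNode Mul u (diff v x))
  -- quotient rule: (u/v)' = (u'v - uv') / v^2
  BinaryNode Div u v ->
    BinaryNode Div
      (BinaryNode Sub (BinaryNode Mul (diff u x) v) (BinaryNode Mul u (diff v x)))
      (BinaryNode Mul v v)
  where
    -- derivative of the outer function, as a tree in its argument w
    outer Neg _ = Const (-1)
    outer Exp w = UnaryNode Exp w
    outer Log w = BinaryNode Div (Const 1) w
    outer Sin w = UnaryNode Cos w
    outer Cos w = UnaryNode Neg (UnaryNode Sin w)

-- Plain numeric evaluation, used here only to check the result.
eval :: Floating a => Map.Map String a -> ExprTree a -> a
eval env tree = case tree of
  Const c -> c
  VarNode v -> env Map.! v
  UnaryNode Neg t -> negate (eval env t)
  UnaryNode Exp t -> exp (eval env t)
  UnaryNode Log t -> log (eval env t)
  UnaryNode Sin t -> sin (eval env t)
  UnaryNode Cos t -> cos (eval env t)
  BinaryNode Add l r -> eval env l + eval env r
  BinaryNode Sub l r -> eval env l - eval env r
  BinaryNode Mul l r -> eval env l * eval env r
  BinaryNode Div l r -> eval env l / eval env r
```

For f = x * sin x, evaluating `diff f "x"` at x = 0.5 should match sin x + x cos x at the same point.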
So far, so good ...
Now, to gain confidence in my implementation, I wanted to check the derivatives against the Numeric.AD module, so I created the following two helpers:
-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a         -- ^ tree
                  -> Map ParamName a    -- ^ paramDict
                  -> ([a] -> a)         -- fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res = evaluate tree (paramDict, varDict)
    varDict = Map.fromList $ zip (getVarNames tree) vals

gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
gradThroughAD tree (paramDict, varDict) = res
  where
    varNames = Map.keys varDict
    varVals = Map.elems varDict
    gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
    res = Map.fromList $ zip varNames gradList
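One way to structure such a cross-check is to make the function handed to the differentiator run entirely in the differentiating type: lift the tree and the parameters up into that type, rather than converting its values back down to `a`. The sketch below is self-contained under loud assumptions: it uses a tiny hand-rolled forward-mode Dual type in place of the ad package's Reverse s a, and a hypothetical cut-down ExprTree with a derived Functor instance (constD plays the lifting role). With the ad package the analogous, untested move would presumably be something like `AD.grad (\vs -> evaluate (fmap auto tree) (fmap auto paramDict, Map.fromList (zip varNames vs))) varVals`, assuming ExprTree (and Cond) get a Functor instance and that `auto`, which as far as I know Numeric.AD exports, lifts constants as expected.

```haskell
{-# LANGUAGE DeriveFunctor #-}
import Data.Map (Map)
import qualified Data.Map as Map

-- Minimal forward-mode dual number: (value, derivative). A stand-in for
-- ad's Reverse s a; this is NOT the ad package.
data Dual a = Dual a a deriving Show

instance Num a => Num (Dual a) where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (dx * y + x * dy)
  negate (Dual x dx) = Dual (negate x) (negate dx)
  abs (Dual x dx) = Dual (abs x) (dx * signum x)
  signum (Dual x _) = Dual (signum x) 0
  fromInteger n = Dual (fromInteger n) 0

-- Lift a constant into Dual with zero derivative (hypothetical helper).
constD :: Num a => a -> Dual a
constD c = Dual c 0

-- Hypothetical cut-down tree; deriving Functor is what allows `fmap constD`.
data ExprTree a
  = Const a
  | ParamNode String
  | VarNode String
  | Add (ExprTree a) (ExprTree a)
  | Mul (ExprTree a) (ExprTree a)
  deriving (Show, Functor)

type Context a = (Map String a, Map String a)

evaluate :: Num a => ExprTree a -> Context a -> a
evaluate tree ctx@(params, vars) = case tree of
  Const c -> c
  ParamNode p -> params Map.! p
  VarNode v -> vars Map.! v
  Add l r -> evaluate l ctx + evaluate r ctx
  Mul l r -> evaluate l ctx * evaluate r ctx

-- Forward-mode gradient: one pass per input, seeding that input with 1.
gradFwd :: Num a => ([Dual a] -> Dual a) -> [a] -> [a]
gradFwd f xs = [tangent (f (seed i)) | i <- [0 .. length xs - 1]]
  where
    seed i = [Dual x (if j == i then 1 else 0) | (j, x) <- zip [0 ..] xs]
    tangent (Dual _ d) = d

-- Lift the tree and parameters *up* into Dual instead of pulling the
-- Dual variables *down* to `a` (which would discard their derivatives).
gradThroughFwd :: Num a => ExprTree a -> Context a -> Map String a
gradThroughFwd tree (params, vars) =
  Map.fromList (zip names (gradFwd f (Map.elems vars)))
  where
    names = Map.keys vars -- same order as Map.elems, so the zips line up
    f ds = evaluate (fmap constD tree) (fmap constD params, Map.fromList (zip names ds))
```

For k * x * y + x with k = 3, x = 2, y = 5 this gives x ↦ 16 and y ↦ 6 (that is, 3y + 1 and 3x). The sketch also uses a single name list (Map.keys) on both sides of the zip; the helpers above zip `getVarNames tree` in one place and `Map.keys varDict` in the other, so if those two orders ever differ, gradient entries would be attached to the wrong variable names.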
Unfortunately, gradThroughAD does not typecheck; GHC reports:
  • Couldn't match type ‘a’
                   with ‘Numeric.AD.Internal.Reverse.Reverse s a’
    ‘a’ is a rigid type variable bound by
      the type signature for:
        gradThroughAD :: forall a.
                         RealFloat a =>
                         ExprTree a -> Context a -> Map VarName a
      at src/SymbolicExpression.hs:452:18
    Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                   -> Numeric.AD.Internal.Reverse.Reverse s a
      Actual type: [a] -> a
  • In the first argument of ‘AD.grad’, namely
      ‘(exprTreeToListFun tree paramDict)’
    In the expression:
      AD.grad (exprTreeToListFun tree paramDict) varVals
    In an equation for ‘gradList’:
        gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
  • Relevant bindings include
      gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
      varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
      varDict :: Map VarName a
        (bound at src/SymbolicExpression.hs:453:32)
      paramDict :: Map ParamName a
        (bound at src/SymbolicExpression.hs:453:21)
      tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
      gradThroughAD :: ExprTree a -> Context a -> Map VarName a
        (bound at src/SymbolicExpression.hs:453:1)
I was a bit surprised by this (naively, I guess), since to me 'a' is polymorphic with a RealFloat constraint, and that is "supposed" to work with AD.
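A likely reading of the error: the `a` in gradThroughAD's signature is chosen by the caller (that is what "rigid" means), so `exprTreeToListFun tree paramDict` has type `[a] -> a` for that one `a`. The RealFloat constraint only says which operations `a` supports; it does not make the function reusable at another type. AD.grad, however, needs a function it can run at `Reverse s a` (the "Expected type" in the error). The self-contained sketch below reproduces the same shape with a hypothetical rank-2 consumer, derivAt, and a hand-rolled Dual type standing in for Reverse s a (neither is part of the ad package).

```haskell
{-# LANGUAGE RankNTypes #-}

-- Minimal forward-mode dual number, a stand-in for ad's Reverse s a.
data Dual a = Dual a a

instance Num a => Num (Dual a) where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (dx * y + x * dy)
  negate (Dual x dx) = Dual (negate x) (negate dx)
  abs (Dual x dx) = Dual (abs x) (dx * signum x)
  signum (Dual x _) = Dual (signum x) 0
  fromInteger n = Dual (fromInteger n) 0

-- Like AD.grad, this consumer picks the type the callback runs at
-- (here Dual a), so the callback must be polymorphic in its number type.
derivAt :: Num a => (forall b. Num b => b -> b) -> a -> a
derivAt f x = let Dual _ d = f (Dual x 1) in d

-- Accepted: works at any Num type, including Dual Double.
poly :: Num b => b -> b
poly x = x * x + 3 * x

-- Rejected if uncommented: the argument is fixed at Double, just as
-- `exprTreeToListFun tree paramDict :: [a] -> a` is fixed at the caller's `a`.
-- bad :: Double
-- bad = derivAt (\x -> x * x + (3 :: Double)) 2
```

Here `derivAt poly 2` is 7, the derivative of x² + 3x at x = 2.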
So, anyway, I tried modifying it as follows:
{-# LANGUAGE Rank2Types #-}
-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a                  -- ^ tree
                  -> Map ParamName a             -- ^ paramDict
                  -> (RealFloat b => [b] -> b)   -- fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res = realToFrac $ evaluate tree (paramDict, varDict)
    varDict = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals
This now typechecks and runs, but going through AD gives me all derivatives as zero, as if the variables (which passed through realToFrac) were treated as constants, so it's not much help either.
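A plausible explanation for the zeros: the Haskell Report defines realToFrac as `fromRational . toRational` (GHC adds rewrite rules for some concrete type pairs), and a Rational has nowhere to keep the derivative. Presumably the AD type's Real instance keeps only the primal value, which would match the all-zero result. The self-contained sketch below shows the effect with a hypothetical monomorphic Dual type over Double standing in for Reverse s Double: the same squaring function loses its derivative once it detours through Double.

```haskell
-- Hypothetical monomorphic dual number, a stand-in for Reverse s Double.
data Dual = Dual Double Double deriving Show

instance Num Dual where
  Dual x dx + Dual y dy = Dual (x + y) (dx + dy)
  Dual x dx - Dual y dy = Dual (x - y) (dx - dy)
  Dual x dx * Dual y dy = Dual (x * y) (dx * y + x * dy)
  negate (Dual x dx) = Dual (negate x) (negate dx)
  abs (Dual x dx) = Dual (abs x) (dx * signum x)
  signum (Dual x _) = Dual (signum x) 0
  fromInteger n = Dual (fromInteger n) 0

instance Fractional Dual where
  Dual x dx / Dual y dy = Dual (x / y) ((dx * y - x * dy) / (y * y))
  fromRational r = Dual (fromRational r) 0 -- a Rational carries no derivative

instance Eq Dual where
  Dual x _ == Dual y _ = x == y

instance Ord Dual where
  compare (Dual x _) (Dual y _) = compare x y

-- toRational can only keep the value; the derivative is dropped here.
instance Real Dual where
  toRational (Dual x _) = toRational x

deriv :: (Dual -> Dual) -> Double -> Double
deriv f x = let Dual _ d = f (Dual x 1) in d

-- Stays in Dual throughout, so the derivative is tracked.
direct :: Dual -> Dual
direct v = v * v

-- Detours through Double via realToFrac, like the second exprTreeToListFun.
viaRealToFrac :: Dual -> Dual
viaRealToFrac v = realToFrac (square (realToFrac v :: Double))
  where
    square y = y * y
```

At x = 3, `deriv direct 3` is 6 while `deriv viaRealToFrac 3` is 0: the value survives the round trip, the derivative does not.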
Any idea how this can be solved?
Thanks again