Hello Café

Issue:
I'm having trouble using the AD package, presumably because I'm using types incorrectly (sorry, I know that's a bit vague).
Any help or pointers on how to solve this would be greatly appreciated.
Thanks in advance.

Background:
I've implemented a symbolic expression tree (full code here: https://pastebin.com/FDpSFuRM):
-- | Expression Tree
data ExprTree a = Const a                             -- ^ number
                | ParamNode ParamName                 -- ^ parameter
                | VarNode VarName                     -- ^ variable
                | UnaryNode MonoOp (ExprTree a)       -- ^ operator of arity 1
                | BinaryNode DualOp (ExprTree a) (ExprTree a) -- ^ operator of arity 2
                | CondNode (Cond a) (ExprTree a) (ExprTree a)     -- ^ conditional node
                deriving (Eq,Show, Generic)

An evaluation function on it:
-- |  evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a
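The full code is only on pastebin, so here is a minimal, self-contained sketch of the kind of tree and evaluation function described above. It is not the actual code. The constructor set (Const, Var, Add, Mul, Sin), the name evalExpr and the variables-only Map String a context are simplifying assumptions. The sketch has no parameters, conditionals or separate operator types.

```haskell
import qualified Data.Map as Map
import Data.Map (Map)

-- Hypothetical, reduced version of ExprTree: only the constructors
-- needed to illustrate evaluation (no params, conditionals or op types).
data Expr a
  = Const a
  | Var String
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  | Sin (Expr a)
  deriving (Eq, Show)

-- Evaluate a tree given values for its variables.
-- An unbound variable is reported with an error.
evalExpr :: Floating a => Expr a -> Map String a -> a
evalExpr (Const c) _ = c
evalExpr (Var v) env =
  case Map.lookup v env of
    Just x  -> x
    Nothing -> error ("unbound variable: " ++ v)
evalExpr (Add l r) env = evalExpr l env + evalExpr r env
evalExpr (Mul l r) env = evalExpr l env * evalExpr r env
evalExpr (Sin e) env   = sin (evalExpr e env)
```

For example, evaluating Add (Mul (Var "x") (Var "x")) (Const 1) with x = 3 at type Double gives 10.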

And a few instances for it:
instance Num a => Default (ExprTree a) where ...
instance Num a => Num (ExprTree a) where ...
instance Fractional a => Fractional (ExprTree a) where ...
instance Floating a => Floating (ExprTree a) where ...
instance (Arbitrary a) => Arbitrary (ExprTree a) where ...

This allows me to easily create (using derivation rules) the derivative(s) of such a tree with respect to its variable(s):
diff :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
grad :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))
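The derivation rules behind diff and grad can be sketched on a similar reduced tree. This sketch rests on the same assumptions as a simplified stand-in: the constructors are hypothetical, Cos is added so that the derivative of Sin can be expressed, and diffExpr and gradExpr are hypothetical names. It applies only the sum, product and chain rules and does not simplify the resulting trees.

```haskell
import qualified Data.Map as Map
import Data.Map (Map)

-- Hypothetical reduced expression tree (not the full ExprTree).
data Expr a
  = Const a
  | Var String
  | Add (Expr a) (Expr a)
  | Mul (Expr a) (Expr a)
  | Sin (Expr a)
  | Cos (Expr a)
  deriving (Eq, Show)

-- Evaluate a tree given values for its variables.
evalExpr :: Floating a => Expr a -> Map String a -> a
evalExpr (Const c) _   = c
evalExpr (Var v) env   = Map.findWithDefault (error ("unbound variable: " ++ v)) v env
evalExpr (Add l r) env = evalExpr l env + evalExpr r env
evalExpr (Mul l r) env = evalExpr l env * evalExpr r env
evalExpr (Sin e) env   = sin (evalExpr e env)
evalExpr (Cos e) env   = cos (evalExpr e env)

-- Symbolic derivative with respect to one variable, using the
-- sum, product and chain rules (the result is not simplified).
diffExpr :: Num a => Expr a -> String -> Expr a
diffExpr (Const _) _ = Const 0
diffExpr (Var v) x   = Const (if v == x then 1 else 0)
diffExpr (Add l r) x = Add (diffExpr l x) (diffExpr r x)
diffExpr (Mul l r) x = Add (Mul (diffExpr l x) r) (Mul l (diffExpr r x))
diffExpr (Sin e) x   = Mul (Cos e) (diffExpr e x)
diffExpr (Cos e) x   = Mul (Const (-1)) (Mul (Sin e) (diffExpr e x))

-- Gradient: one derivative tree per variable name.
gradExpr :: Num a => Expr a -> [String] -> Map String (Expr a)
gradExpr e vars = Map.fromList [(v, diffExpr e v) | v <- vars]
```

For x*y + sin x, the x-derivative evaluates to y + cos x. However, the tree still contains terms such as Mul (Var "x") (Const 0), which a simplification pass would typically prune.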

So far, so good ...

Now, to gain confidence in my implementation, I wanted to check the derivatives against the Numeric.AD module,
so I created the following two functions:
-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a      -- ^ tree
                  -> Map ParamName a -- ^ paramDict
                  -> ([a] -> a)      -- fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res            = evaluate tree (paramDict, varDict)
    varDict        = Map.fromList $ zip (getVarNames tree) vals


gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
gradThroughAD tree (paramDict, varDict) = res
  where
    varNames = Map.keys varDict
    varVals  = Map.elems varDict
    gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
    res      = Map.fromList $ zip varNames gradList


Unfortunately, it does not type check. GHC reports:
• Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
  ‘a’ is a rigid type variable bound by
    the type signature for:
      gradThroughAD :: forall a. RealFloat a => ExprTree a -> Context a -> Map VarName a
    at src/SymbolicExpression.hs:452:18
  Expected type: [Numeric.AD.Internal.Reverse.Reverse s a] -> Numeric.AD.Internal.Reverse.Reverse s a
    Actual type: [a] -> a
• In the first argument of ‘AD.grad’, namely ‘(exprTreeToListFun tree paramDict)’
  In the expression: AD.grad (exprTreeToListFun tree paramDict) varVals
  In an equation for ‘gradList’:
      gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
• Relevant bindings include
    gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
    varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
    varDict :: Map VarName a (bound at src/SymbolicExpression.hs:453:32)
    paramDict :: Map ParamName a (bound at src/SymbolicExpression.hs:453:21)
    tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
    gradThroughAD :: ExprTree a -> Context a -> Map VarName a
      (bound at src/SymbolicExpression.hs:453:1)


I was a bit surprised by this (naively, I guess), since to me 'a' is polymorphic with a RealFloat constraint, and that is "supposed" to work with AD.
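The mismatch seems to come down to who chooses the type. In gradThroughAD, 'a' is fixed by whoever calls it, so exprTreeToListFun tree paramDict has type [a] -> a for that one 'a'. Judging by the expected type in the error, AD.grad needs a function it can run at its own number type, Reverse s a, and AD.grad rather than the caller picks s. The same situation can be reproduced without the ad package. In this hypothetical example, useAtTwoTypes plays the role of AD.grad. Because its argument has a rank-2 type, the function passed in must work at every Num type, and useAtTwoTypes then decides which types to use (Int and Double).

```haskell
{-# LANGUAGE RankNTypes #-}

-- Hypothetical stand-in for AD.grad: the argument must work for every
-- Num type, and the callee (not the caller) picks Int and Double.
useAtTwoTypes :: (forall b. Num b => b -> b) -> (Int, Double)
useAtTwoTypes f = (f 2, f 2.5)

-- Accepted: the lambda is polymorphic in its number type.
example :: (Int, Double)
example = useAtTwoTypes (\x -> x * x + 1)

-- Rejected (kept commented out), analogous to passing
-- exprTreeToListFun tree paramDict inside gradThroughAD: g's type is
-- fixed to a -> a by the outer signature, so it cannot be used at Int
-- and Double.
-- bad :: Num a => (a -> a) -> (Int, Double)
-- bad g = useAtTwoTypes g
```

Here example evaluates to (5, 7.25). Uncommenting bad makes GHC reject it with a rigid-type-variable error of the same kind as the one above.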
 
So I tried modifying it as follows:
{-# LANGUAGE Rank2Types    #-}

-- | helper for AD usage
exprTreeToListFun :: (RealFloat a)
                  => ExprTree a      -- ^ tree
                  -> Map ParamName a -- ^ paramDict
                  -> (RealFloat b => [b] -> b)      -- fun from var values to their evaluation
exprTreeToListFun tree paramDict vals = res
  where
    res            = realToFrac $ evaluate tree (paramDict, varDict)
    varDict        = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals
 
This now type checks and runs, but going through AD returns all derivatives as zero, as if the variables (having passed through realToFrac) were treated as constants. So it is not much help either.
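The zeros are consistent with how realToFrac is defined. The Prelude defines it as fromRational . toRational, and GHC adds rewrite rules only for specific type pairs. toRational of an AD-style number can only return its value, not its derivative. The sketch below reproduces the effect with a hand-rolled forward-mode Dual type. Dual is hypothetical and not part of the ad package, and I am assuming ad's own Real instances behave the same way. Routing the variable through realToFrac gives a zero derivative. Running the computation at the dual type and lifting only the constant (the job ad's auto is meant for) gives the expected derivative.

```haskell
-- Hypothetical forward-mode dual number: (value, derivative).
-- A stand-in for ad's number types, not part of the ad package.
data Dual = Dual Double Double deriving Show

-- Comparisons look at the value only.
instance Eq Dual where
  Dual a _ == Dual b _ = a == b

instance Ord Dual where
  compare (Dual a _) (Dual b _) = compare a b

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a' * b + a * b')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0

instance Fractional Dual where
  Dual a a' / Dual b b' = Dual (a / b) ((a' * b - a * b') / (b * b))
  fromRational r        = Dual (fromRational r) 0

-- toRational can only see the value, so the derivative is dropped here.
instance Real Dual where
  toRational (Dual a _) = toRational a

-- Derivative of f at x: seed the tangent with 1 and read it back.
derivative :: (Dual -> Dual) -> Double -> Double
derivative f x = let Dual _ d = f (Dual x 1) in d

-- Models the realToFrac version: the computation is fixed at Double,
-- and the input and output cross realToFrac, losing the tangent.
viaRealToFrac :: Dual -> Dual
viaRealToFrac v = realToFrac (g (realToFrac v))
  where
    g :: Double -> Double
    g x = 3 * x * x

-- Generic version: runs at the caller's number type. Only the constant
-- (3, standing in for a parameter value) is lifted, via the
-- hypothetical 'lift' argument.
generic :: Num b => (Double -> b) -> b -> b
generic lift x = lift 3 * x * x

-- Lift a constant into Dual with a zero derivative.
constD :: Double -> Dual
constD c = Dual c 0
```

At x = 2, derivative viaRealToFrac 2 is 0, while derivative (generic constD) 2 is 12 (that is, 6x). Applied to the real code, this would suggest evaluating the tree at the AD type itself and lifting only the parameter values.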

Any idea how this can be solved?

Apologies if this question already has an answer somewhere. Poking around SO, it looks like I'm not the only one with similar issues. Unfortunately, none of the answers I found really helped; if anything, they confirm the same issue of zero derivatives: http://stackoverflow.com/questions/36878083/ad-type-unification-error-with-constrained-type-vector

Thanks again

--
Frederic Cogny
+33 7 83 12 61 69