
Hello Café

*Issue:*
I'm having trouble using the AD package, presumably due to a wrong use of types (sorry, I know it's a bit vague). Any help or pointer on how to solve this would be greatly appreciated. Thanks in advance.

*Background:*
I've implemented a symbolic expression tree [full code here: https://pastebin.com/FDpSFuRM]

    -- | Expression Tree
    data ExprTree a
      = Const a                                    -- ^ number
      | ParamNode ParamName                        -- ^ parameter
      | VarNode VarName                            -- ^ variable
      | UnaryNode MonoOp (ExprTree a)              -- ^ operator of arity 1
      | BinaryNode DualOp (ExprTree a) (ExprTree a) -- ^ operator of arity 2
      | CondNode (Cond a) (ExprTree a) (ExprTree a) -- ^ conditional node
      deriving (Eq, Show, Generic)

An evaluation function on it:

    -- | evaluates an Expression Tree on its Context (Map ParamName a, Map VarName a)
    evaluate :: (Eq a, Ord a, Floating a) => ExprTree a -> Context a -> a

And a few instances on it:

    instance Num a => Default (ExprTree a) where ...
    instance Num a => Num (ExprTree a) where ...
    instance Fractional a => Fractional (ExprTree a) where ...
    instance Floating a => Floating (ExprTree a) where ...
    instance (Arbitrary a) => Arbitrary (ExprTree a) where ...

This allows me to easily create (using derivation rules) the derivative(s) of such a tree with respect to its variable(s):

    diff    :: (Eq a, Floating a) => ExprTree a -> VarName -> ExprTree a
    grad    :: (Eq a, Floating a) => ExprTree a -> Map VarName (ExprTree a)
    hessian :: (Eq a, Floating a) => ExprTree a -> Map VarName (Map VarName (ExprTree a))

So far, so good ...

Now, to gain assurance in my implementation, I wanted to check the derivatives against the Numeric.AD module, so I created the following two functions:

    -- | helper for AD usage
    exprTreeToListFun :: (RealFloat a)
                      => ExprTree a        -- ^ tree
                      -> Map ParamName a   -- ^ paramDict
                      -> ([a] -> a)        -- fun from var values to their evaluation
    exprTreeToListFun tree paramDict vals = res
      where
        res     = evaluate tree (paramDict, varDict)
        varDict = Map.fromList $ zip (getVarNames tree) vals

    gradThroughAD :: RealFloat a => ExprTree a -> Context a -> Map VarName a
    gradThroughAD tree (paramDict, varDict) = res
      where
        varNames = Map.keys varDict
        varVals  = Map.elems varDict
        gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
        res      = Map.fromList $ zip varNames gradList

Unfortunately, it does not type check, with the message:

    • Couldn't match type ‘a’ with ‘Numeric.AD.Internal.Reverse.Reverse s a’
      ‘a’ is a rigid type variable bound by
        the type signature for:
          gradThroughAD :: forall a.
                           RealFloat a =>
                           ExprTree a -> Context a -> Map VarName a
        at src/SymbolicExpression.hs:452:18
      Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                     -> Numeric.AD.Internal.Reverse.Reverse s a
        Actual type: [a] -> a
    • In the first argument of ‘AD.grad’, namely
        ‘(exprTreeToListFun tree paramDict)’
      In the expression:
        AD.grad (exprTreeToListFun tree paramDict) varVals
      In an equation for ‘gradList’:
          gradList = AD.grad (exprTreeToListFun tree paramDict) varVals
    • Relevant bindings include
        gradList :: [a] (bound at src/SymbolicExpression.hs:457:5)
        varVals :: [a] (bound at src/SymbolicExpression.hs:456:5)
        varDict :: Map VarName a
          (bound at src/SymbolicExpression.hs:453:32)
        paramDict :: Map ParamName a
          (bound at src/SymbolicExpression.hs:453:21)
        tree :: ExprTree a (bound at src/SymbolicExpression.hs:453:15)
        gradThroughAD :: ExprTree a -> Context a -> Map VarName a
          (bound at src/SymbolicExpression.hs:453:1)

I was a bit surprised (I guess naively) by this, since to me 'a' is polymorphic with RealFloat as its type constraint, and that is "supposed" to work with AD.

So anyway, I tried to modify it as follows (the changes are the RealFloat b => result type and the two realToFrac calls):

    {-# LANGUAGE Rank2Types #-}

    -- | helper for AD usage
    exprTreeToListFun :: (RealFloat a)
                      => ExprTree a                  -- ^ tree
                      -> Map ParamName a             -- ^ paramDict
                      -> (RealFloat b => [b] -> b)   -- fun from var values to their evaluation
    exprTreeToListFun tree paramDict vals = res
      where
        res     = realToFrac $ evaluate tree (paramDict, varDict)
        varDict = Map.fromList $ zip (getVarNames tree) $ map realToFrac vals

This now typechecks and runs, but going through AD returns *all derivatives as zero*, as if the variables (which passed through *realToFrac*) were treated as constants, so it is not much help either.

Any idea how this can be solved?
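
To make the zero-derivative symptom concrete, here is a minimal sketch of what the realToFrac round trip might be doing. It does not use the ad package: Dual, diffDual, square and squareViaDouble are hypothetical names made up for illustration, and Dual is only a toy stand-in for AD's number type, not its actual implementation. The assumption is that any conversion going through toRational can carry the value only, not the derivative.

```haskell
-- Toy forward-mode number (hypothetical, NOT the ad package's type):
-- a value together with its derivative.
data Dual = Dual { primal :: Double, tangent :: Double }
  deriving (Eq, Ord, Show)

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)  -- product rule
  negate (Dual a a')    = Dual (negate a) (negate a')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0          -- constants: zero tangent

instance Fractional Dual where
  recip (Dual a a') = Dual (recip a) (negate a' / (a * a))
  fromRational r    = Dual (fromRational r) 0             -- constants: zero tangent

-- A Rational has room for the value only, so the tangent is dropped here.
instance Real Dual where
  toRational = toRational . primal

-- Derivative of f at x: seed the tangent with 1 and read it back.
diffDual :: (Dual -> Dual) -> Double -> Double
diffDual f x = tangent (f (Dual x 1))

-- Polymorphic function: works directly on Dual, so the tangent flows through.
square :: Num a => a -> a
square x = x * x

-- Same shape as the Rank2Types attempt: convert in, evaluate at a fixed
-- type, convert back out. Both conversions go through Rational.
squareViaDouble :: Dual -> Dual
squareViaDouble x = realToFrac (square (realToFrac x :: Double))
```

In GHCi, diffDual square 3 evaluates to 6.0, while diffDual squareViaDouble 3 evaluates to 0.0: the value is still computed correctly, but the result comes back as a constant. If the same reasoning applies to AD's Reverse type, it would suggest that the function handed to AD.grad needs to evaluate the tree at AD's own number type rather than converting in and out of 'a'.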

Apologies if this is a question that already has an answer somewhere, but poking around SO (http://stackoverflow.com/search?q=%5Bhaskell%5D+ad+package), it looks like I'm not the only one having similar issues. Unfortunately, none of the answers I found really helped me; if anything, they confirm the same issue of null derivatives: http://stackoverflow.com/questions/36878083/ad-type-unification-error-with-c...

Thanks again

-- 
Frederic Cogny
+33 7 83 12 61 69