{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Symbolic differentiation of numeric functions.
--
-- 'Code' values from "Data.Symbolic.TypedCode" are made instances of 'Num',
-- 'Fractional' and 'Floating', so an ordinary polymorphic numeric function
-- applied to a code variable yields its body as code.  'diffC'
-- differentiates such code with respect to a variable, 'simpleC' applies a
-- few algebraic simplifications, and 'diff_fn' combines both steps.
module Data.Symbolic.Diff where

import Data.Maybe (fromMaybe)
import Data.Symbolic.TypedCode

-- | Arithmetic on 'Code' builds code: each method applies the matching
-- operator code (@op'add@, @op'sub@, ...) to its operands with 'appC'.
-- NOTE(review): 'abs' and 'signum' are not defined, so calling them on
-- 'Code' fails at run time.
instance Num a => Num (Code a) where
  x + y = op'add `appC` x `appC` y
  x - y = op'sub `appC` x `appC` y
  x * y = op'mul `appC` x `appC` y
  negate x = op'negate `appC` x
  fromInteger = integerC

instance Fractional a => Fractional (Code a) where
  x / y = op'div `appC` x `appC` y
  recip x = op'recip `appC` x
  fromRational = rationalC

-- | Only 'pi', 'sin' and 'cos' are defined explicitly; among the
-- 'Floating'-specific operations, 'diffC' knows the derivatives of 'sin'
-- and 'cos' only.
instance Floating a => Floating (Code a) where
  pi = op'pi
  sin x = op'sin `appC` x
  cos x = op'cos `appC` x

-- Examples ------------------------------------------------------------

-- | A numeric constant; at type @'Code' Int@ it becomes the code for
-- @1 + 2@.
testf1 :: Num a => a
testf1 = 1 + 2

testf1' = return (testf1 :: Code Int)

testf1'' :: IO ()
testf1'' = showQC testf1'

test1f :: Num a => a -> a
test1f x = let y = x * x in y + 1

test1 :: Float
test1 = test1f (2.0 :: Float)

-- Apply test1f to a fresh differentiation variable, keeping the variable.
test1c = new'diffVar >>= \(v :: Var Float) -> return (test1f (var'exp v), v)

test1r = test1c >>= \(c, v) -> reflectDF v c

test1cp :: IO ()
test1cp = showQC test1r

-- | @diffC v c@ is the derivative of the expression @c@ with respect to the
-- variable @v@, built with the sum, difference, product and quotient rules
-- and the rules for negation, reciprocal, 'sin' and 'cos'.
--
-- A literal has derivative 0.  A variable has derivative 1 when @on'varC@
-- returns 'Left' and 0 when it returns 'Right' (presumably 'Left' means
-- "this is @v@ itself" -- confirm against "Data.Symbolic.TypedCode").
-- Any other construct is reported with 'error'.
diffC :: (Floating a, Floating b) => Var b -> Code a -> Code a
diffC v c
  | Just _ <- on'litC c = 0
  | Just ev <- on'varC v c = either (const 1) (const 0) ev
  | Just (x, y) <- on'2opC op'add c = d x + d y
  | Just (x, y) <- on'2opC op'sub c = d x - d y
  | Just (x, y) <- on'2opC op'mul c = d x * y + x * d y
  | Just (x, y) <- on'2opC op'div c = (d x * y - x * d y) / (y * y)
  | Just x <- on'1opC op'negate c = negate (d x)
  | Just x <- on'1opC op'recip c = negate (d x) / (x * x)
  | Just x <- on'1opC op'sin c = d x * cos x
  | Just x <- on'1opC op'cos c = negate (d x * sin x)
  | otherwise = error $ "Cannot handle code: " ++ show c
  where
    -- every operand above has the same type as @c@, so one local
    -- (monomorphic) shorthand suffices
    d = diffC v

test1d = test1c >>= \(c, v) -> reflectDF v $ diffC v c

test1dp :: IO ()
test1dp = showQC test1d

-- | Simplify code by applying 'simpleCL' repeatedly until no rule fires.
simpleC :: Floating a => Var b -> Code a -> Code a
simpleC v c = maybe c (simpleC v) (simpleCL v c)

-- | One simplification pass: 'Just' the rewritten code if some rule fired
-- anywhere in @c@, 'Nothing' if @c@ is already simple.
--
-- Operands are simplified first (see @simple'recur@); the local rule of a
-- node is tried only when none of its operands changed.  Local rules:
-- @0 + y = y@, @x + 0 = x@, @x - 0 = x@, @0 * y = 0@, @x * 0 = 0@,
-- @1 * y = y@, @x * 1 = x@, @0 / y = 0@, @negate 0 = 0@, and folding of
-- @+@, @-@, @*@ applied to two numeric literals.  Arguments of 'recip',
-- 'sin' and 'cos' are simplified, but those nodes have no local rule.
simpleCL :: Floating a => Var b -> Code a -> Maybe (Code a)
simpleCL v c
  | Just _ <- on'litC c = Nothing
  | Just _ <- on'varC v c = Nothing
  | Just (x, y) <- on'2opC op'add c = simple'recur op'add sadd v x y
  | Just (x, y) <- on'2opC op'sub c = simple'recur op'sub ssub v x y
  | Just (x, y) <- on'2opC op'mul c = simple'recur op'mul smul v x y
  | Just (x, y) <- on'2opC op'div c = simple'recur op'div sdiv v x y
  | Just x <- on'1opC op'negate c = simple'recur1 op'negate sneg v x
  | Just x <- on'1opC op'recip c = simple'recur1 op'recip noRule v x
  | Just x <- on'1opC op'sin c = simple'recur1 op'sin noRule v x
  | Just x <- on'1opC op'cos c = simple'recur1 op'cos noRule v x
  | otherwise = Nothing
  where
    -- 0 + y = y;  x + 0 = x;  fold two literals
    sadd x y
      | Just 0 <- on'litRationalC x = Just y
      | Just 0 <- on'litRationalC y = Just x
      | Just p <- on'litRationalC x, Just q <- on'litRationalC y =
          Just (fromRational (p + q))
      | otherwise = Nothing

    -- x - 0 = x;  fold two literals
    ssub x y
      | Just 0 <- on'litRationalC y = Just x
      | Just p <- on'litRationalC x, Just q <- on'litRationalC y =
          Just (fromRational (p - q))
      | otherwise = Nothing

    -- 0 * y = 0;  x * 0 = 0;  1 * y = y;  x * 1 = x;  fold two literals
    smul x y
      | Just 0 <- on'litRationalC x = Just (fromRational 0)
      | Just 0 <- on'litRationalC y = Just (fromRational 0)
      | Just 1 <- on'litRationalC x = Just y
      | Just 1 <- on'litRationalC y = Just x
      | Just p <- on'litRationalC x, Just q <- on'litRationalC y =
          Just (fromRational (p * q))
      | otherwise = Nothing

    -- 0 / y = 0
    sdiv x _
      | Just 0 <- on'litRationalC x = Just (fromRational 0)
      | otherwise = Nothing

    -- negate 0 = 0
    sneg x
      | Just 0 <- on'litRationalC x = Just (fromRational 0)
      | otherwise = Nothing

    -- nodes whose only simplification is inside their argument
    noRule _ = Nothing

-- | Simplify a binary node @op x y@: if either operand simplifies, rebuild
-- the node from the simplified operand(s); otherwise try the node's local
-- rule @fn@ on the unchanged operands.
simple'recur op fn v x y =
  case (simpleCL v x, simpleCL v y) of
    (Nothing, Nothing) -> fn x y
    (mx, my) -> Just (op `appC` fromMaybe x mx `appC` fromMaybe y my)

-- | Unary counterpart of @simple'recur@.
simple'recur1 op fn v x = maybe (fn x) (Just . appC op) (simpleCL v x)

test1ds = test1c >>= \(c, v) -> reflectDF v $ simpleC v $ diffC v c

test1dsp :: IO ()
test1dsp = showQC test1ds

-- | Differentiate a numeric function: apply @f@ to a fresh differentiation
-- variable to get its body as code, differentiate and simplify that body,
-- and hand the result to @reflectDF@ (presumably abstracting over the
-- variable to give code of type @b -> b@).
diff_fn :: Floating b => (forall a. Floating a => a -> a) -> QCode (b -> b)
diff_fn f = do
  v <- new'diffVar
  reflectDF v . simpleC v . diffC v $ f (var'exp v)

-- | Show (via @showQC@) the code of @f@ applied to a fresh variable,
-- without differentiating it.
show_fn :: (forall a. Floating a => a -> a) -> IO ()
show_fn f = showQC $ do
  v <- new'diffVar
  reflectDF v (f (var'exp v))

-- | Horner evaluation of @x^2 + 2x + 3@.
test2f :: Num a => a -> a
test2f x = foldl (\z c -> x * z + c) 0 [1, 2, 3]

test2n :: Float
test2n = test2f (4 :: Float)

test2s :: IO ()
test2s = show_fn test2f

test2ds :: IO ()
test2ds = showQC (diff_fn test2f)

test11f :: Num a => a -> a
test11f x = 2 * x + 3 * x

test11ds :: IO ()
test11ds = showQC (diff_fn test11f)

test5f :: Floating a => a -> a
test5f x = sin (5 * x + pi / 2) + cos (1 / x)

test5n :: Float
test5n = test5f (pi :: Float)

test5ds :: IO ()
test5ds = showQC (diff_fn test5f)

test3f :: Fractional a => a -> a -> a
test3f x y = (x * y + (5 * x * x)) / y

-- Partial derivatives of test3f: with respect to x at fixed y, and with
-- respect to y at fixed x.
test4x :: (Integral i, Floating b) => i -> QCode (b -> b)
test4x y = diff_fn (\x -> test3f x (fromIntegral y))

test4y :: Floating b => Integer -> QCode (b -> b)
test4y x = diff_fn (test3f (fromInteger x))

test4xds :: IO ()
test4xds = showQC (test4x 1)

test4yds :: IO ()
test4yds = showQC (test4y 5)