{-# OPTIONS_GHC -Wno-unused-local-binds -Wno-orphans -Wno-unused-top-binds -Wno-type-defaults #-}
{-# LANGUAGE DeriveFunctor, FlexibleContexts, PatternSynonyms, ViewPatterns , NoOverloadedStrings #-}
-- | Inference of expectation transformers for PWhile programs: the solver
-- monad, result reporting, and the estimation of expectations over sampled
-- values.
module PWhile.InferEt
  ( Result (..)
  , FilterResult (..)
  , Log
  , SolveM
  , run, runAny
  , ect, evt
  , showResult
  ) where

import Control.Applicative
import Control.Monad.Except
import qualified Control.Monad.State as St
import Data.Maybe (listToMaybe)
import qualified Data.Map.Strict as M
import Data.Tree (Forest)
import qualified Data.Set as S
import qualified GUBS.Polynomial as P
import qualified ListT as L
import qualified Text.PrettyPrint.ANSI.Leijen as PP

import Control.Monad.Trace
-- import Debug.Trace as D
import PWhile.InvariantSolver
import PWhile.InvariantSolver.Naive
import Data.PWhile.BoolExpression
import qualified Data.PWhile.BoolExpression as B
import qualified Data.PWhile.Expression as E
import qualified Data.PWhile.CostExpression as C
import Data.PWhile.Program
import PWhile.Util


--- * Solver Monad
---------------------------------------------------------------------------------------------------

-- | Cost expressions with 'Rational' coefficients, as used throughout this module.
type CExp = C.CExp Rational

-- | State layer holding the memo table: results are keyed by the transformer
-- kind, the program and the continuation expression.
type MemoT = St.StateT (M.Map (Knd,C,CExp) CExp)

-- | Run a memoising computation starting from an empty memo table; the final
-- table is discarded.
runMemoT :: Monad m => MemoT m v -> m v
runMemoT = flip St.evalStateT M.empty

-- | Wrap a transformer so that a result already stored for the key @(k,c,e)@
-- is returned directly; on a miss the wrapped function is run and its result
-- is stored before being returned.
memoized :: (Knd -> C -> CExp -> SolveM CExp) -> Knd -> C -> CExp -> SolveM CExp
memoized f k c e = do
  memo <- liftMemoT St.get
  case M.lookup (k,c,e) memo of
    Nothing -> do
      v <- f k c e
      liftMemoT $ St.modify (M.insert (k,c,e) v)
      return v
    Just v  -> return v

-- | The solver monad: 'String' errors on top of 'String' tracing on top of a
-- 'L.ListT' stream of alternatives, over the memo table and 'IO'.
newtype SolveM a = SolveM { runSolveM_ :: ExceptT String (TraceT String (L.ListT (MemoT IO))) a }
  deriving (Applicative, Functor, Monad, MonadError String, MonadTrace String, MonadIO)

-- | Provides a stream of possibly failing computations with logging: every
-- element pairs the outcome ('Left' error or 'Right' value) with its trace.
runSolveM' :: SolveM a -> L.ListT (MemoT IO) (Either String a, Forest String)
runSolveM' = runTraceT . runExceptT . runSolveM_

-- | Run a solver computation with a fresh memo table and collect every
-- element of the stream produced by 'runSolveM''.
runSolveM :: SolveM a -> IO [(Either String a, Forest String)]
runSolveM = runMemoT . L.toList . runSolveM'

-- * Alternative

-- | Choice is delegated to the layers below 'ExceptT': both operands are
-- unwrapped to that level and combined with the underlying '<|>'.
instance Alternative SolveM where
  empty = SolveM $ ExceptT $ TraceT empty
  ma <|> mb = do
    let a = runExceptT (runSolveM_ ma)
        b = runExceptT (runSolveM_ mb)
    SolveM $ ExceptT $ a <|> b

-- | Lift an action on the memo table (or plain 'IO') into 'SolveM'.
liftMemoT :: MemoT IO a -> SolveM a
liftMemoT = SolveM . lift . lift . lift

-- | Lift a stream of alternatives into 'SolveM'.
liftListT :: L.ListT (MemoT IO) a -> SolveM a
liftListT = SolveM . lift . lift

-- | Offer every element of a container as an alternative result.
alternative :: Foldable f => f a -> SolveM a
alternative = liftListT . L.fromFoldable

-- | Walk the stream until an element satisfies the predicate and return it as
-- a singleton list; elements before it are dropped. Returns @[]@ when the
-- stream ends without a match.
first :: Monad m => (a -> Bool) -> L.ListT m a -> m [a]
first p = go where
  go m = do
    aM <- L.uncons m
    case aM of
      Nothing    -> return []
      Just (a,l) -> if p a then return [a] else go l

-- ** Logging

-- | Run the computation, log its pretty-printed result, and return the result
-- unchanged.
logMsgIdM :: PP.Pretty a => SolveM a -> SolveM a
logMsgIdM ma = ma >>= \a -> logMsg (renderPretty $ PP.pretty a) >> return a

-- | Like 'logMsgIdM', but wrapped in a log block labelled with the given string.
logBlkIdM :: PP.Pretty a => String -> SolveM a -> SolveM a
logBlkIdM s = logBlk s . logMsgIdM

-- | Applicative debug output tagged with this module's name (@"InferEt"@).
debugM :: (Applicative f, PP.Pretty e) => String -> e -> f e
debugM = debugMsgA "InferEt"

-- | Pure debug output tagged with this module's name (@"InferEt"@).
debug :: (PP.Pretty e) => String -> e -> e
debug = debugMsg "InferEt"

-- ** Result

-- | The trace collected while computing one result.
type Log = Forest String

-- | Outcome of one element of the solver stream, together with its log.
data Result a
  = Failure String Log
  | Success a Log

-- | Renders the log first, then a coloured @[Failure]@ / @[Success]@ line.
instance PP.Pretty a => PP.Pretty (Result a) where
  pretty (Failure m l) = renderLog l PP.<$$> PP.red (PP.text "[Failure]" PP.<+> PP.text m)
  pretty (Success f l) = renderLog l PP.<$$> PP.green (PP.text "[Success]" PP.<+> PP.pretty f)

-- | Render a 'Result' to a 'String' via its 'PP.Pretty' instance.
showResult :: PP.Pretty a => Result a -> String
showResult = show . PP.pretty

-- NOTE(review): in this disabled snippet the operators between the documents
-- were missing; 'PP.<$$>' below is a guess - confirm before re-enabling.
-- instance PP.Pretty [Result] where
--   pretty rs = PP.vcat
--     [ PP.pretty "***" PP.<$$> PP.pretty r PP.<$$> PP.pretty "<<<" PP.<$$> PP.text "" | r <- rs]

-- | Selects which elements of the solver stream 'run' reports.
data FilterResult
  = Any      -- return first successful result
  | Take Int -- take first n results
  | Pick Int -- pick ith result
  | All      -- take all results
  deriving (Show, Eq)

-- | Run a solver computation with a fresh memo table and report the elements
-- of its result stream chosen by the filter:
--
-- * 'Any' yields the first 'Success' only (failures before it are dropped),
--   or nothing if there is none;
-- * @'Take' i@ yields the first @i@ elements;
-- * @'Pick' i@ yields the @i@-th element (counting from 1) if it exists;
-- * 'All' yields every element.
run :: FilterResult -> SolveM a -> IO [Result a]
run t = fmap (map toResult) . runMemoT . select t . runSolveM' where
  select Any      = first success where
    success (fst -> Right{}) = True
    success _                = False
  select (Take i) = L.toList . L.take i
  -- keep the first i elements, then skip the i-1 in front of the wanted one
  select (Pick i) = L.toList . L.drop (pred i) . L.take i
  select All      = L.toList
  toResult (Left err, l) = Failure err l
  toResult (Right a, l)  = Success a l

-- | The first successful result, if any (see 'Any').
runAny :: SolveM a -> IO (Maybe (Result a))
runAny m = listToMaybe <$> run Any m


-- cost functions
----------------------------------------------------------------------

-- | The guarded norms of a cost expression that are not constant.
varGNorms :: (Eq c, Num c) => C.CExp c -> [C.GNorm c]
varGNorms = filter (not . C.isConstGN) . C.gNorms


-- estimating expectations
----------------------------------------------------------------------

-- | Weighted average over a list of @(weight, element)@ pairs: every element
-- is mapped to a cost expression, scaled by its weight, and the sum is
-- divided by the total weight. A singleton list short-circuits to the
-- element's own value, ignoring its weight.
--
-- NOTE(review): an empty list reaches the general clause and divides by an
-- empty (zero) total weight - confirm callers never pass @[]@.
discreteExpectation :: [(E.Exp, e)] -> (e -> SolveM CExp) -> SolveM CExp
discreteExpectation [(_,e)] f = f e
discreteExpectation ls      f =
  C.divBy <$> (C.sum <$> sequence [C.scale w <$> f e | (w,e) <- ls'])
          <*> pure (sum [e | (e,_) <- ls'])
  where
    -- weights converted to the coefficient type of 'CExp'
    ls' = [ (fmap fromIntegral w, e) | (w,e) <- ls ]

-- | Closed form for @f@ applied to the fresh variable @\@i@ over the bounds
-- @(l,o)@ (presumably the sum of @f i@ for @i@ from @l@ to @o@ - confirm).
--
-- The body is decomposed structurally: zero stays zero, sums are handled
-- summand-wise, and a conditional with a non-zero else-branch is split into
-- its two guarded parts. Anything else is handed to 'solveInvariant', trying
-- two alternatives: counting down (result instantiated at @o@) and counting
-- up (result instantiated at @l@).
boundedSum :: (E.Exp -> CExp) -> (E.Exp, E.Exp) -> SolveM CExp
boundedSum f (l,o) = logBlkIdM msg $ bs fi
  where
    i   = E.Var "@i"
    vi  = E.variable i
    fi  = f vi
    msg = "[sum] " ++ renderPretty (PP.text "sum" PP.<> PP.tupled [ PP.pretty l, PP.pretty o, PP.pretty fi])

    bs g | g == C.zero = return C.zero
    bs (C.Plus g1 g2)  = C.plus <$> bs g1 <*> bs g2
    bs (C.Cond g g1 g2)
      | g2 /= C.zero   = C.plus <$> bs (C.guarded g g1) <*> bs (C.guarded (neg g) g2)
    bs g = do
        logMsg2 "norms" ns
        (ui,s,dir) <- alternative [(uiDown, o, "Down"), (uiUp, l, "Up")]
        logMsg2 ("SumFn["++dir++"]") ui
        f' <- debugM "result" =<< solveInvariant ui
        logMsg2 "Solution" f'
        return (E.substitute i s f')
      where
        -- index moves from o towards l: each step replaces @i by @i - 1
        uiDown = UpperInvariant
          { inv  = o .>= vi
          , cnd  = vi .>= l
          , cost = g
          , step = (E.substitute i (vi - 1) . C.fromNorm) `map` ns
          , cont = C.zero
          , limt = ns }
        -- index moves from l towards o: each step replaces @i by @i + 1
        uiUp = UpperInvariant
          { inv  = vi .>= l
          , cnd  = o .>= vi
          , cost = g
          , step = (E.substitute i (vi + 1) . C.fromNorm) `map` ns
          , cont = C.zero
          , limt = ns }
        -- candidate norms: one per (non-constant guarded norm, >= literal of
        -- its guard), built from the literal's distance e1 - e2 + 1
        ns = [ C.norm (e .>= 0) e `C.mulN` n
             | C.GNorm b _ n <- varGNorms g
             , e1 :>=: e2 <- S.toList (literals b)
             , let e = e1 - e2 + 1]

-- | Expectation of @f@ over a distribution.
--
-- * @'E.Discrete' ls@ is a weighted average via 'discreteExpectation'.
-- * @'E.Rand' n@ is guarded by @n > 0@. If @f@ yields the same expression for
--   two distinct fresh variables, that expression is the result. For a
--   constant @n'@ the values @0 .. n'-1@ are enumerated with equal weight.
--   Otherwise the result is @'boundedSum' f (0, n-1)@ divided by @n@.
expectation :: E.Dist E.Exp -> (E.Exp -> CExp) -> SolveM CExp
expectation (E.Discrete ls) f = discreteExpectation ls (return . f)
-- expectation (E.Rand (E.Constant n)) f =
--   discreteExpectation [(1,E.constant i) | i <- [0..n-1]] (return . f)
expectation (E.Rand n) f = logBlkIdM ("[expectation] " ++ "rand(" ++ show (PP.pretty n) ++ ")") $ do
    let (vi,vj) = (E.Var "@i", E.Var "@j")
        (fi,fj) = (f (E.variable vi), f (E.variable vj))
    logMsg2 "f" fi
    C.guarded (n .> 0) <$> case n of
      _ | fi == fj -> return fi
      -- TODO!
      --   | Just g <- linFn fi vi (C.scale (1 / 2) (C.E n)) -> return g
      E.Constant n' -> discreteExpectation [(1,E.constant i) | i <- [0..n'-1]] (return . f)
      _             -> C.divBy <$> boundedSum f (0,n-1) <*> pure (fmap fromIntegral n)
  where
    -- only referenced by the disabled linFn code below
    noccur v e = v `notElem` E.variables e

-- TODO: add linFn
-- linFn :: CExp -> E.Var -> CExp -> Maybe (CExp)
-- linFn (C.guardedExpressions -> [(Top,p C.:% q)]) v avg
--   | not (v `S.member` C.variables (C.toCost q)) =
--     (*) <$> (sum <$> sequence [linMono m v avg | m <- P.toMonos p ])
--         <*> pure (C.toCost (1 C.:% q))
-- linFn _ _ _ = Nothing
-- linMono (coeff,P.toPowers -> powers) v avg = C.scale coeff <$> lin powers where
--   lin [] = return 1
--   lin (p:ps)
--     | v `noccur` fst p = (*) (C.fromPowers [p]) <$> lin ps
--   lin ((C.Norm e,1):ps) = (*) <$> linExp e v avg <*> con ps
--   lin _ = Nothing
--   con [] = return 1
--   con (p:ps)
--     | v `noccur` fst p = (*) (C.fromPowers [p]) <$> con ps
--   con _ = Nothing
-- linExp :: E.Exp -> E.Var -> CExp -> Maybe (CExp)
-- linExp e v avg
--   | c >= 0 && d >= 0
--     && E.constant c * E.variable v + E.constant d == e =
--       Just (C.scale (toRational c) avg + C.constant (toRational d))
--   | otherwise = Nothing
--   where
--     c = P.coefficientOf (P.fromPowers [(v,1)]) e
--     d = P.coefficientOf (P.fromPowers []) e


-- Expectation Transformer
----------------------------------------------------------------------

-- | Polynomial over norms with 'Int' coefficients, used as a template for
-- ranking-function candidates.
type NormTemplate = P.Polynomial C.Norm Int

-- | Structural comparison of two templates via 'P.zipCoefficientsWith':
-- a pair of coefficients counts as equal whatever their values, while the two
-- one-sided handlers yield 'False' (presumably they cover monomials present
-- in only one of the templates - confirm against "GUBS.Polynomial").
normTemplateEq :: NormTemplate -> NormTemplate -> Bool
p1 `normTemplateEq` p2 =
  and $ P.coefficients $ P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2

-- | Template for a single norm: a norm over a constant expression becomes the
-- constant coefficient 1, any other norm becomes a template variable.
norm :: C.Norm -> NormTemplate
norm (C.Norm _ (E.Constant _)) = P.coefficient 1
norm n                         = P.variable n

-- | Template for the norm of @e@ guarded by @e >= 0@.
normFromExp :: E.Exp -> NormTemplate
normFromExp e = norm (C.Norm (e .>= 0) e)

-- | Sum of the templates for @e1 - e2 + 1@ over every literal @e1 >= e2@ of
-- the Boolean expression.
normFromBExp :: BExp -> NormTemplate
normFromBExp b = sum [ normFromExp (e1 - e2 + 1) | e1 :>=: e2 <- S.toList (literals b) ]

-- | Sum, over the non-constant guarded norms of the cost expression that have
-- a positive coefficient and a non-zero norm, of the guard's template plus
-- the norm's template.
normFromCExp :: C.CExp Rational -> NormTemplate
normFromCExp d = sum
  [ normFromBExp b + norm n
  | C.GNorm b k n <- varGNorms d
  , k > 0, not (C.isZeroN n) ]

-- | Replace every norm of a template by the template computed for it;
-- coefficients are carried over unchanged.
substituteM :: (C.Norm -> SolveM NormTemplate) -> NormTemplate -> SolveM NormTemplate
substituteM s = P.fromPolynomialM s (pure . P.coefficient)

-- | Run a template computation inside a log block named @\<name\>-template@.
template :: String -> SolveM NormTemplate -> SolveM NormTemplate
template name = logBlkIdM (name++"-template")

-- | Candidate norms for the loop with body @body@, invariant @i@, guard @c@,
-- per-iteration cost @g@ and continuation @f@.
--
-- Starting from a linear template, the following alternatives are offered in
-- order: the linear template itself, its \"shift-avg\", \"conditions\" and
-- \"shift-max\" refinements, and finally the square of the linear template.
-- A refinement that does not change the template (see 'normTemplateEq') is
-- dropped. Each refinement is listed exactly once; repeating an alternative
-- would only re-run the same search.
--
-- The result lists one norm per monomial of the chosen template, without
-- norms whose expression is zero.
extractRanking :: C -> BExp -> BExp -> CExp -> CExp -> SolveM [C.Norm]
extractRanking body i c g f = fmap toNormList $ do
    lin <- linearTemplate
    template "linear" (pure lin)
      <|> template "shift-avg" (refine shiftAvg lin)
      <|> template "conditions" (refine conds lin)
      <|> template "shift-max" (refine shiftMax lin)
      <|> do
        let sq = lin * lin
        template "square" (pure sq)
  where
    -- one product norm per monomial; coefficients are ignored
    toNormList p = filter (not . zeroN) [ C.prodN [ C.expN n k | (n,k) <- P.toPowers m ] | (_,m) <- P.toMonos p]
      where zeroN (C.Norm _ e) = e == 0

    -- apply a refinement, failing when it leaves the template unchanged
    refine m r = do
      r' <- m r
      if r' `normTemplateEq` r then empty else pure r'

    -- norms taken from the invariant, the guard, the guarded cost and the
    -- guarded continuation, plus guard norms times the change of the
    -- continuation norms (see delta)
    linearTemplate = do
      let gNorm = normFromCExp (C.guarded c g)
          fNorm = normFromCExp (C.guarded (neg c) f)
      df <- substituteM delta fNorm
      pure (normFromExp 1 + normFromBExp i + normFromBExp c + fNorm + gNorm + sum [ normFromBExp c * df ])

    -- replace each norm e by the maximum of e and 2*e - e', where e' ranges
    -- over the (floored) linear forms of the norm's value after the body
    shiftAvg = substituteM s where
      s n@(C.Norm _ e) = do
        evn <- etM Evt body (C.nm 1 n)
        pure $ normFromExp $ E.maxE $ e : [2 * e - fmap floor e' | e' <- lfs evn ]

    -- as shiftAvg, but e' ranges over the expressions of the non-constant
    -- guarded norms of the norm's value after the body
    shiftMax = substituteM s where
      s n@(C.Norm _ e) = do
        evn <- etM Evt body (C.nm 1 n)
        pure $ normFromExp $ E.maxE $ e : [2 * e - e' | C.GNorm _ _ (C.Norm _ e') <- varGNorms evn ]

    -- add the guards occurring in the norm's value after the body
    conds = substituteM s where
      s n = do
        evn <- etM Evt body (C.nm 1 n)
        pure $ norm n + sum [ normFromBExp b | C.GNorm b _ _ <- varGNorms evn ]

    -- largest norm expression after the body, minus the norm's own expression
    delta n@(C.Norm _ e) = do
      ns <- S.toList . C.norms <$> etM Evt body (C.nm 1 n)
      let d = E.maxE [ ne | C.Norm _ ne <- ns ] - e
      return (normFromExp d)

    -- linear forms of a cost expression: alternatives of Sup/Cond are
    -- collected, Plus combines every pair of alternatives.
    -- NOTE(review): only these constructors are handled; any other shape
    -- raises a pattern-match failure - confirm this covers all of C.CExp.
    lfs (C.Sup c1 c2)        = lfs c1 ++ lfs c2
    lfs (C.Cond _ c1 c2)     = lfs c1 ++ lfs c2
    lfs (C.Div d _)          = lfs d
    lfs (C.Plus c1 c2)       = [e1 + e2 | e1 <- lfs c1, e2 <- lfs c2]
    lfs (C.N k (C.Norm _ e)) = [E.constant k * fmap fromIntegral e]

-- | @f@ mentions none of the variables assigned anywhere in the program @c@.
isConstantWrt :: CExp -> C -> Bool
f `isConstantWrt` c = C.variables f `S.disjoint` vs
  where vs = S.fromList [ v | Ass v _ <- subPrograms c]

-- | The program contains no 'Tic' sub-program.
isTickFree :: C -> Bool
isTickFree = not . any isTick . subPrograms where
  isTick Tic{} = True
  isTick _     = False

-- | The program contains no 'While' sub-program.
isSimple :: C -> Bool
isSimple = not . any isWhile . subPrograms where
  isWhile While{} = True
  isWhile _       = False

-- | Solve an upper-invariant problem with @NaiveSolver 1 2@; throws the error
-- @"no solution"@ when the solver returns 'Nothing'.
solveInvariant :: UpperInvariant -> SolveM CExp
solveInvariant ui =
  maybe (throwError "no solution") return $ solve (NaiveSolver 1 2) =<< debugM "invariant" ui

-- | Which transformer is computed: 'Ect' accounts for the cost of 'Tic'
-- statements, 'Evt' ignores them (see 'et').
data Knd = Evt | Ect deriving (Eq, Ord, Show)

-- | Logged entry points for the two transformers: both open a log block
-- showing the continuation @f@, log the program, and defer to the memoised
-- transformer 'etM'.
ect, evt :: C -> CExp -> SolveM CExp
ect c f = logBlkIdM ("[ect] " ++ show (PP.pretty f)) $ logMsg2 "Problem" c *> etM Ect c f
evt c f = logBlkIdM ("[evt] " ++ show (PP.pretty f)) $ logMsg2 "Problem" c *> etM Evt c f

-- | Memoised version of 'et'.
etM :: Knd -> C -> CExp -> SolveM CExp
etM = memoized et

-- | The transformer itself, defined by cases on the program with
-- continuation @f@. The first three equations are shortcuts; the rest follow
-- the program structure.
et :: Knd -> C -> CExp -> SolveM CExp
-- shortcut: no ticks and zero continuation cost nothing
et Ect c f | f == C.zero && isTickFree c = return C.zero
-- shortcut: a loop program cannot change an f over variables it never assigns
et Evt c f | not (isSimple c) && f `isConstantWrt` c = return f
-- shortcut: for loop programs, treat the summands of f separately
et Evt c (C.Plus f1 f2) | not (isSimple c) = C.plus <$> evt c f1 <*> evt c f2
et _ Abort _ = return C.zero
et _ Skip f = return f
et Evt (Tic _) f = return f
et Ect (Tic e) f = return $ C.ramp e `C.plus` f
et _ (Ass v d) f = expectation d (\ e -> E.substitute v e f)
et t (NonDet e1 e2) f = C.sup <$> etM t e1 f <*> etM t e2 f
et t (Choice ls) f = discreteExpectation ls (\ c -> etM t c f)
et t (Cond _ i b c1 c2) f = C.guarded i <$> (C.cond b <$> etM t c1 f <*> etM t c2 f)
et Ect (Seq c1 c2) f
  -- NOTE(review): this case calls 'et' directly and so bypasses the memo
  -- table, unlike the generic Seq case below - confirm this is intended.
  | isSimple c1 = et Ect c1 =<< et Ect c2 f
  -- c1 contains a loop: its own cost and the propagated remainder are
  -- computed separately and added
  | otherwise   = C.plus <$> ect c1 C.zero <*> (evt c1 =<< ect c2 f)
et t (Seq c1 c2) f = etM t c1 =<< etM t c2 f
et t (While _ i b c) f = logBlkIdM "[While.Step]" $ do
  -- cost of one iteration of the body; zero when ticks are ignored
  g <- case t of {Evt -> return C.zero; Ect -> ect c C.zero}
  ns <- extractRanking c i b g f
  logMsg2 "Norms" ns
  -- value of every candidate norm after one iteration of the body
  hs <- traverse (evt c . C.fromNorm) ns
  let ui = UpperInvariant
        { inv  = i
        , cnd  = b
        , cost = g
        , step = hs
        , cont = f
        , limt = ns}
  logMsg2 "Invariant" ui
  h <- solveInvariant ui
  return (C.cond b h f)

-- * Pretty

-- | Shows the constructor name.
instance PP.Pretty Knd where
  pretty = PP.text . show