Commit 8de9c4c9 authored by AVANZINI Martin's avatar AVANZINI Martin
Browse files

expectation of linear functions

parent 801f5172
This diff is collapsed.
......@@ -104,8 +104,8 @@ plus c1 c2 | c1 == zero = c2
| c2 == zero = c1
| otherwise = Plus c1 c2
ramp :: Num c => E.Exp -> CExp c
ramp e = N 1 (norm (e .>= 0) e)
ramp :: (Eq c, Num c) => E.Exp -> CExp c
ramp e = nm 1 (norm (e .>= 0) e)
sum :: Coeff c => [CExp c] -> CExp c
sum = foldl plus zero
......
......@@ -35,7 +35,7 @@ import qualified Data.PWhile.BoolExpression as B
import qualified Data.PWhile.Expression as E
import qualified Data.PWhile.CostExpression as C
import Data.PWhile.Program
import PWhile.Util
import qualified PWhile.Util as U
--- * Solver Monad ---------------------------------------------------------------------------------------------------
......@@ -44,6 +44,7 @@ type CExp = C.CExp Rational
type MemoT = St.StateT (M.Map (Knd,C,CExp) CExp)
runMemoT :: Monad m => MemoT m v -> m v
runMemoT = flip St.evalStateT M.empty
......@@ -99,17 +100,20 @@ first p = go where
-- ** Logging
logMsgIdM :: PP.Pretty a => SolveM a -> SolveM a
logMsgIdM ma = ma >>= \a -> logMsg (renderPretty $ PP.pretty a) >> return a
logResult :: PP.Pretty a => String -> SolveM a -> SolveM a
logResult n ma = ma >>= \a -> U.logMsg2 n a >> return a
logData :: PP.Pretty a => String -> a -> SolveM ()
logData = U.logMsg2
logBlkIdM :: PP.Pretty a => String -> SolveM a -> SolveM a
logBlkIdM s = logBlk s . logMsgIdM
logBlk :: PP.Pretty a => String -> SolveM a -> SolveM a
logBlk s m = U.logBlk s (m >>= \a -> U.logMsg a >> return a)
debugM :: (Applicative f, PP.Pretty e) => String -> e -> f e
debugM = debugMsgA "InferEt"
debugM = U.debugMsgA "InferEt"
debug :: (PP.Pretty e) => String -> e -> e
debug = debugMsg "InferEt"
debug = U.debugMsg "InferEt"
-- ** Result
......@@ -119,8 +123,8 @@ data Result a = Failure String Log
| Success a Log
instance PP.Pretty a => PP.Pretty (Result a) where
pretty (Failure m l) = renderLog l PP.<$$> PP.red (PP.text "[Failure]" PP.<+> PP.text m)
pretty (Success f l) = renderLog l PP.<$$> PP.green (PP.text "[Success]" PP.<+> PP.pretty f)
pretty (Failure m l) = U.renderLog l PP.<$$> PP.red (PP.text "[Failure]" PP.<+> PP.text m)
pretty (Success f l) = U.renderLog l PP.<$$> PP.green (PP.text "[Success]" PP.<+> PP.pretty f)
showResult :: PP.Pretty a => Result a -> String
......@@ -157,74 +161,124 @@ runAny m = listToMaybe <$> run Any m
varGNorms :: (Eq c, Num c) => C.CExp c -> [C.GNorm c]
varGNorms = filter (not . C.isConstGN) . C.gNorms
-- templates
----------------------------------------------------------------------
type NormTemplate = P.Polynomial C.Norm Int
templateToNorms :: NormTemplate -> [C.Norm]
templateToNorms p =
filter (not . zeroN) [ C.prodN [ C.expN n k | (n,k) <- P.toPowers m ]
| (_,m) <- P.toMonos p]
where zeroN (C.Norm _ e) = e == 0
normTemplateEq :: NormTemplate -> NormTemplate -> Bool
p1 `normTemplateEq` p2 = and $ P.coefficients $
P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2
norm :: C.Norm -> NormTemplate
norm (C.Norm _ (E.Constant _)) = P.coefficient 1
norm n = P.variable n
normFromExp :: E.Exp -> NormTemplate
normFromExp e = norm (C.Norm (e .>= 0) e)
normFromBExp :: BExp -> NormTemplate
normFromBExp b = sum [ normFromExp (e1 - e2 + 1)
| e1 :>=: e2 <- S.toList (literals b) ]
normFromCExp :: C.CExp Rational -> NormTemplate
normFromCExp d = sum [ normFromBExp b + norm n
| C.GNorm b k n <- varGNorms d
, k > 0, not (C.isZeroN n) ]
substituteM :: (C.Norm -> SolveM NormTemplate) -> NormTemplate -> SolveM NormTemplate
substituteM s = P.fromPolynomialM s (pure . P.coefficient)
template :: String -> SolveM NormTemplate -> SolveM NormTemplate
template name = logBlk (name++"-template")
-- estimating expectations
----------------------------------------------------------------------
discreteExpectation :: [(E.Exp, e)] -> (e -> SolveM CExp) -> SolveM CExp
discreteExpectation [(_,e)] f = f e
discreteExpectation ls f =
finiteExpectation :: [(E.Exp, e)] -> (e -> SolveM CExp) -> SolveM CExp
finiteExpectation [(_,e)] f = f e
finiteExpectation ls f =
C.divBy <$> (C.sum <$> sequence [C.scale w <$> f e | (w,e) <- ls'])
<*> pure (sum [e | (e,_) <- ls'])
where ls' = [ (fmap fromIntegral w, e) | (w,e) <- ls ]
boundedSum :: (E.Exp -> CExp) -> (E.Exp, E.Exp) -> SolveM CExp
boundedSum f (l,o) = logBlkIdM msg $ bs fi where
i = E.Var "@i"
vi = E.variable i
fi = f vi
msg = "[sum] " ++ renderPretty (PP.text "sum" PP.<> PP.tupled [ PP.pretty l, PP.pretty o, PP.pretty fi])
bs g
| g == C.zero = return C.zero
bs (C.Plus g1 g2) = C.plus <$> bs g1 <*> bs g2
boundedSum :: (E.Var,CExp) -> (E.Exp, E.Exp) -> SolveM CExp
boundedSum (i,f) (l,o) = bs f
where
bs g | not (i `S.member` C.variables g) = return $ C.scale (o - l + 1) g
bs (C.Plus g1 g2) = C.plus <$> bs g1 <*> bs g2
bs (C.Cond g g1 g2)
| g2 /= C.zero = C.plus <$> bs (C.guarded g g1) <*> bs (C.guarded (neg g) g2)
bs g = do
logMsg2 "norms" ns
(ui,s,dir) <- alternative [(uiDown, o, "Down"), (uiUp, l, "Up")]
logMsg2 ("SumFn["++dir++"]") ui
f' <- debugM "result" =<< solveInvariant ui
logMsg2 "Solution" f'
return (E.substitute i s f')
where
| g2 /= C.zero = C.plus <$> bs (C.guarded g g1)
<*> bs (C.guarded (neg g) g2)
| not (i `S.member` B.variables g) = C.guarded g <$> bs g1
bs (C.N k (C.Norm g e))
| linearCase = pure $
C.N (k/2) (C.norm (g .&& df .>= 0) (df * E.substitute i df e))
where
df = o - l
linearCase = all (all lin . P.toPowers . snd) (P.toMonos e)
lin (vi,m) = vi /= i || m <= 1
bs g = logBlk "bounded-sum" $ do
logData "expression" (PP.text "sum"
PP.<> PP.braces (PP.pretty g
PP.<+> PP.text "|"
PP.<+> PP.pretty l
PP.<+> PP.text "<="
PP.<+> PP.pretty i
PP.<+> PP.text "<="
PP.<+> PP.pretty o))
t <- template "guarded" (pure (sum [ nn * ng
| C.GNorm b _ n <- varGNorms g
, let nn = norm n
, let ng = sum [ normFromExp (e1 - e2 + 1)
| e1 :>=: e2 <- S.toList (literals b) ]]))
-- <|> template "iterate" (pure (sum [ norm (E.substitute i 1 n) * normFromExp (E.variable i) * normFromExp (E.variable i)
-- | C.GNorm _ _ n <- varGNorms g
-- , ng <- [normFromExp (l - o + 1), normFromExp (o - l + 1)] ]))
let ns = templateToNorms t
vi = E.variable i
uiDown = UpperInvariant {
inv = o .>= vi
, cnd = vi .>= l
inv = o .>= vi
, cnd = vi .>= l
, cost = g
, step = (E.substitute i (vi - 1) . C.fromNorm) `map` ns
, cont = C.zero
, limt = ns }
uiUp = UpperInvariant {
inv = vi .>= l
, cnd = o .>= vi
inv = vi .>= l
, cnd = o .>= vi
, cost = g
, step = (E.substitute i (vi + 1) . C.fromNorm) `map` ns
, cont = C.zero
, limt = ns }
ns = [ C.norm (e .>= 0) e `C.mulN` n
| C.GNorm b _ n <- varGNorms g
, e1 :>=: e2 <- S.toList (literals b)
, let e = e1 - e2 + 1]
(ui,s) <- alternative [(uiDown, o), (uiUp, l)]
E.substitute i s <$> (debugM "result" =<< solveInvariant ui)
expectation :: E.Dist E.Exp -> (E.Exp -> CExp) -> SolveM CExp
expectation (E.Discrete ls) f = discreteExpectation ls (return . f)
expectation (E.Discrete ls) f = finiteExpectation ls (return . f)
-- expectation (E.Rand (E.Constant n)) f =
-- discreteExpectation [(1,E.constant i) | i <- [0..n-1]] (return . f)
expectation (E.Rand n) f =
logBlkIdM ("[expectation] " ++ "rand(" ++ show (PP.pretty n) ++ ")") $ do
let (vi,vj) = (E.Var "@i", E.Var "@j")
(fi,fj) = (f (E.variable vi), f (E.variable vj))
logMsg2 "f" fi
-- finiteExpectation [(1,E.constant i) | i <- [0..n-1]] (return . f)
expectation r@(E.Rand n) f =
logBlk "expectation" $ do
logData "dist" r
logData "f" fv
C.guarded (n .> 0) <$>
case n of
_ | fi == fj -> return fi
-- TODO!
-- | Just g <- linFn fi vi (C.scale (1 / 2) (C.E n)) -> return g
E.Constant n' -> discreteExpectation [(1,E.constant i) | i <- [0..n'-1]] (return . f)
_ -> C.divBy <$> boundedSum f (0,n-1) <*> pure (fmap fromIntegral n)
_ | not (v `S.member` C.variables fv) -> return fv
-- | Just g <- linFn fi vi (C.scaleC (1 / 2) (C.ramp n)) -> return g
E.Constant n' -> finiteExpectation [(1,E.constant i) | i <- [0..n'-1]] (return . f)
_ -> C.divBy <$> boundedSum (v, fv) (0,n-1) <*> pure (fmap fromIntegral n)
where
noccur v e = v `notElem` E.variables e
v = E.Var "@i"
fv = f (E.variable v)
-- noccur v e = v `notElem` E.variables e
-- TODO: add linFn
-- linFn :: CExp -> E.Var -> CExp -> Maybe (CExp)
......@@ -246,53 +300,14 @@ expectation (E.Rand n) f =
-- | v `noccur` fst p = (*) (C.fromPowers [p]) <$> con ps
-- con _ = Nothing
-- linExp :: E.Exp -> E.Var -> CExp -> Maybe (CExp)
-- linExp e v avg
-- | c >= 0 && d >= 0
-- && E.constant c * E.variable v + E.constant d == e =
-- Just (C.scale (toRational c) avg + C.constant (toRational d))
-- | otherwise = Nothing
-- where
-- c = P.coefficientOf (P.fromPowers [(v,1)]) e
-- d = P.coefficientOf (P.fromPowers []) e
-- Expectation Transformer
----------------------------------------------------------------------
type NormTemplate = P.Polynomial C.Norm Int
normTemplateEq :: NormTemplate -> NormTemplate -> Bool
p1 `normTemplateEq` p2 = and $ P.coefficients $
P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2
norm :: C.Norm -> NormTemplate
norm (C.Norm _ (E.Constant _)) = P.coefficient 1
norm n = P.variable n
normFromExp :: E.Exp -> NormTemplate
normFromExp e = norm (C.Norm (e .>= 0) e)
normFromBExp :: BExp -> NormTemplate
normFromBExp b = sum [ normFromExp (e1 - e2 + 1)
| e1 :>=: e2 <- S.toList (literals b) ]
normFromCExp :: C.CExp Rational -> NormTemplate
normFromCExp d = sum [ normFromBExp b + norm n
| C.GNorm b k n <- varGNorms d
, k > 0, not (C.isZeroN n) ]
substituteM :: (C.Norm -> SolveM NormTemplate) -> NormTemplate -> SolveM NormTemplate
substituteM s = P.fromPolynomialM s (pure . P.coefficient)
template :: String -> SolveM NormTemplate -> SolveM NormTemplate
template name = logBlkIdM (name++"-template")
extractRanking :: C -> BExp -> BExp -> CExp -> CExp -> SolveM [C.Norm]
extractRanking body i c g f = fmap toNormList $ do
extractRanking body i c g f = fmap templateToNorms $ do
let gNorm = normFromCExp g
fNorm = normFromCExp f
grdNorm = normFromBExp i + normFromBExp c
......@@ -305,10 +320,6 @@ extractRanking body i c g f = fmap toNormList $ do
<|> template "mixed" (pure (grdNorm * grdNorm + lin))
<|> template "square" (pure (lin * lin + lin))
where
toNormList p = filter (not . zeroN) [ C.prodN [ C.expN n k | (n,k) <- P.toPowers m ]
| (_,m) <- P.toMonos p]
where zeroN (C.Norm _ e) = e == 0
refine m r = do
r' <- m r
if r' `normTemplateEq` r then empty else pure r'
......@@ -366,12 +377,12 @@ data Knd = Evt | Ect deriving (Eq, Ord, Show)
ect, evt :: C -> CExp -> SolveM CExp
ect c f =
logBlkIdM ("[ect] " ++ show (PP.pretty f)) $
logMsg2 "Problem" c *> etM Ect c f
logBlk "Expected Cost" $
logData "f" f >> logData "Program" c *> etM Ect c f
evt c f =
logBlkIdM ("[evt] " ++ show (PP.pretty f)) $
logMsg2 "Problem" c *> etM Evt c f
logBlk "Expected Cost" $
logData "f" f >> logData "Program" c *> etM Evt c f
etM :: Knd -> C -> CExp -> SolveM CExp
etM = memoized et
......@@ -392,17 +403,16 @@ et Evt (Tic _) f = return f
et Ect (Tic e) f = return $ C.ramp e `C.plus` f
et _ (Ass v d) f = expectation d (\ e -> E.substitute v e f)
et t (NonDet e1 e2) f = C.sup <$> etM t e1 f <*> etM t e2 f
et t (Choice ls) f = discreteExpectation ls (\ c -> etM t c f)
et t (Choice ls) f = finiteExpectation ls (\ c -> etM t c f)
et t (Cond _ i b c1 c2) f = C.guarded i <$> (C.cond b <$> etM t c1 f <*> etM t c2 f)
et Ect (Seq c1 c2) f
| isSimple c1 = et Ect c1 =<< et Ect c2 f
| otherwise = C.plus <$> ect c1 C.zero <*> (evt c1 =<< ect c2 f)
et t (Seq c1 c2) f = etM t c1 =<< etM t c2 f
et t (While _ i b c) f =
logBlkIdM "[While.Step]" $ do
g <- case t of {Evt -> return C.zero; Ect -> ect c C.zero}
ns <- extractRanking c i b g f
logMsg2 "Norms" ns
et t d@(While _ i b c) f = logBlk "While.step" $ do
logData "Problem" d
g <- case t of {Evt -> return C.zero; Ect -> logBlk "Expected Cost Body" $ et Ect c C.zero}
ns <- logResult "Norms" $ extractRanking c i b g f
hs <- traverse (et Evt c . C.fromNorm) ns
let ui = UpperInvariant { inv = i
, cnd = b
......@@ -410,9 +420,10 @@ et t (While _ i b c) f =
, step = hs
, cont = f
, limt = ns}
logMsg2 "Invariant" ui
h <- solveInvariant ui
return (C.cond b h f)
logData "Invariant" ui
solveInvariant ui
-- h <- solveInvariant ui
-- return (C.cond b h f)
-- * Pretty
......
......@@ -242,14 +242,14 @@ assertEntailSmt :: Int -> DNF SmtLit -> SmtLit -> SmtM ()
assertEntailSmt i (Disj cs) p = forM_ cs assertEntailConj
where
assertEntailConj c
| d entailed = return ()
| entailed = return ()
| otherwise = SMT.assert =<< handelman i c p
where
entailed = p == TT || p `S.member` litsSet c || any (`entailLit` p) c
Geq0 eq `entailLit` Geq0 ep
| E.Constant (E.Constant n) <- ep - eq = n >= 0
_ `entailLit` _ = False
d = debug (show (PP.pretty (Disj cs) PP.<> PP.text " |- " PP.<> PP.pretty p))
-- d = debug (show (PP.pretty (Disj cs) PP.<> PP.text " |- " PP.<> PP.pretty p))
assertConstraintSmt :: Int -> Constraint SmtPoly -> SmtM ()
assertConstraintSmt i (GEQ g l r) = assertEntailSmt i g (l `geqLit` r)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment