Commit a6b6c13c authored by AVANZINI Martin's avatar AVANZINI Martin
Browse files

new strategy for templates

parent e75ecf8c
......@@ -90,7 +90,7 @@ fromNorm :: Num c => Norm -> CExp c
fromNorm = N 1
zero :: (Num c) => CExp c
zero = N 1 (Norm Top 0)
zero = N 0 (Norm Top 1)
one :: (Num c) => CExp c
one = N 1 (Norm Top 1)
......@@ -99,8 +99,7 @@ plus :: Coeff c => CExp c -> CExp c -> CExp c
plus (Cond g c1 c2) (Cond g' c1' c2')
| g == g' = Cond g (c1 `plus` c1') (c2 `plus` c2')
plus (Div c1 e1) (Div c2 e2) | e1 == e2 = (c1 `plus` c2) `divBy` e1
plus (N k1 n1) (N k2 n2) | k1 == k2 = nm k1 (n1 `plusN` n2)
| n1 == n2 = nm (k1 + k2) n1
plus (N k1 n1) (N k2 n2) | n1 == n2 = nm (k1 + k2) n1 -- | k1 == k2 = nm k1 (n1 `plusN` n2)
plus c1 c2 | c1 == zero = c2
| c2 == zero = c1
| otherwise = Plus c1 c2
......@@ -120,20 +119,32 @@ sup c1 c2 = Sup c1 c2
cond :: (Eq c, Num c) => BExp -> CExp c -> CExp c -> CExp c
cond Top c _ = c
cond Bot _ d = d
cond _ c d
| c == d && c /= oneV = c
cond g (Cond g' c1 c2) d
| c2 == d = cond (g .&& g') c1 d
cond _ c d
| c == d = c
cond g c d = Cond g c d
simplify :: CExp Rational -> CExp Rational
simplify (N k n) = nm k n
simplify (Div c e) = simplify c `divBy` e
simplify (Plus c d) = simplify c `plus` simplify d
simplify (Sup c d) = simplify c `sup` simplify d
simplify (Cond b c d) =
case cond b (simplify c) (simplify d) of
Cond _ c' d' | c' == d' -> c'
r -> r
guarded :: (Eq c, Num c) => BExp -> CExp c -> CExp c
guarded g c = cond g c zero
oneV :: Num c => CExp c
oneV = N 1 (Norm Top "@@ONE@@")
-- oneV :: Num c => CExp c
-- oneV = N 1 (Norm Top "@@ONE@@")
pred :: (Eq c, Num c) => BExp -> CExp c
pred g = cond g oneV zero
predicate :: (Eq c, Num c) => BExp -> CExp c
predicate g = cond g one zero
-- class RatioFactor a where rfactor :: a -> (Rational, a)
......@@ -182,24 +193,6 @@ instance Coeff c => E.Substitutable (CExp c) where
substitute s e (Sup c d) = E.substitute s e c `sup` E.substitute s e d
substitute s e (Cond b c d) = cond (E.substitute s e b) (E.substitute s e c) (E.substitute s e d)
instance PP.Pretty Norm where
pretty (Norm g e) = PP.brackets (PP.hang 0 (PP.pretty e PP.</> PP.text "|" PP.<+> PP.pretty g))
instance (PP.Pretty c, Eq c, Num c) => PP.Pretty (CExp c) where
pretty = pp id where
pp par (N 1 n) = par (ppNorm n)
pp par (N k n) = PP.pretty k PP.<> PP.text "·" PP.<> par (ppNorm n)
pp par (Div c e) = par (infx (pp id c) "/" (PP.pretty e))
pp _ (Plus c d) = infx (pp PP.parens c) "+" (pp PP.parens d)
pp _ (Sup c d) = ppFun "sup" [ PP.pretty c, PP.pretty d ]
pp _ (Cond g c d)
| d == zero = PP.hang 1 (PP.brackets (PP.pretty g) PP.<> PP.text "·" PP.<> pp PP.parens c)
| otherwise = PP.text "ite" PP.<> PP.tupled [ PP.pretty g, pp id c, pp id d ]
ppFun f ls = PP.text f PP.<> PP.tupled ls
infx l s r = PP.hang 1 (l PP.</> PP.text s PP.<+> r)
ppNorm (Norm _ e) = PP.pretty e
-- guarded norms
data GNorm c = GNorm BExp c Norm
......@@ -224,7 +217,7 @@ fromGNorm (GNorm g k n) = guarded g (N k n)
gNorms :: (Eq c, Num c) => CExp c -> [GNorm c]
gNorms = walk B.Top 1 where
walk B.Bot _ _ = []
walk b k (N _ e) = [GNorm b k e]
walk b k (N k' e) = [GNorm b (k * k') e]
walk b k (Div c _) = walk b k c
walk b k (Plus c d) = walk b k c ++ walk b k d
walk b k (Sup c d) = walk b k c ++ walk b k d
......@@ -233,5 +226,28 @@ gNorms = walk B.Top 1 where
norms :: (Eq c, Num c) => CExp c -> S.Set Norm
norms c = S.fromList [n | GNorm _ _ n <- gNorms c]
-- pretty printers
instance PP.Pretty Norm where
pretty (Norm g e) = PP.brackets (PP.hang 0 (PP.pretty e PP.</> PP.text "|" PP.<+> PP.pretty g))
instance (PP.Pretty c, Eq c, Num c) => PP.Pretty (CExp c) where
pretty = pp id where
-- pp _ (N 0 _) = PP.text "0"
-- pp par (N 1 n) = par (ppNorm n)
-- pp par (N k (Norm _ 1)) = PP.pretty k
pp _ (N k n) = PP.pretty k PP.<> PP.text "·" PP.<> PP.parens (ppNorm n)
pp par (Div c e) = par (infx (pp id c) "/" (PP.pretty e))
pp par (Plus c d) = par (infx (pp id c) "+" (pp id d))
pp _ (Sup c d) = ppFun "sup" [ PP.pretty c, PP.pretty d ]
pp _ (Cond g c d)
| d == zero = PP.hang 1 (PP.brackets (PP.pretty g) PP.<> PP.text "·" PP.<> pp PP.parens c)
| otherwise = PP.text "ite" PP.<> PP.tupled [ PP.pretty g, pp id c, pp id d ]
ppFun f ls = PP.text f PP.<> PP.tupled ls
infx l s r = PP.hang 1 (l PP.</> PP.text s PP.<+> r)
ppNorm (Norm _ e) = PP.pretty e
instance (PP.Pretty c, Eq c, Ord c, Num c) => PP.Pretty (GNorm c) where
pretty = PP.pretty . fromGNorm
......@@ -216,7 +216,7 @@ expectation (E.Rand n) f =
-- TODO!
-- | Just g <- linFn fi vi (C.scale (1 / 2) (C.E n)) -> return g
E.Constant n' -> discreteExpectation [(1,E.constant i) | i <- [0..n'-1]] (return . f)
_ | otherwise -> C.divBy <$> boundedSum f (0,n-1) <*> pure (fmap fromIntegral n)
_ -> C.divBy <$> boundedSum f (0,n-1) <*> pure (fmap fromIntegral n)
where
noccur v e = v `notElem` E.variables e
......@@ -257,61 +257,127 @@ expectation (E.Rand n) f =
type NormPoly = P.Polynomial C.Norm Int
normPolyEq :: NormPoly -> NormPoly -> Bool
p1 `normPolyEq` p2 = and $ P.coefficients $ P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2
extractRanking :: C -> BExp -> BExp -> CExp -> CExp -> SolveM [C.Norm]
extractRanking body i c g f = do
combine <-
(logMsg "additive" >> pure additive)
<|> (logMsg "multiplicative" >> pure multiplicative)
rankFromBExp <-
(logMsg "optimistic" >> pure (ranksFromBExpWith optimistic))
<|> (logMsg "shift" >> pure (ranksFromBExpWith shift))
<|> (logMsg "squared-optimistic" >> pure (squared (ranksFromBExpWith optimistic)))
gRanks <- combine rankFromBExp (c .&& i) (varGNorms g)
fRanks <- pure (sum (fromGNorm `map` C.gNorms f))
fRanks' <- (*) <$> rankFromBExp c
<*> P.fromPolynomialM (fmap P.variable . delta) (pure . P.coefficient) fRanks
return (toNormList (norm 1 + fRanks + gRanks + fRanks')) --
where
toNormList p = filter (not . zeroN) [ C.prodN [ C.expN n i | (n,i) <- P.toPowers m ]
| (_,m) <- P.toMonos p]
extractRanking body i c g f = toNormList <$> refine 1
where
toNormList p = filter (not . zeroN) [ C.prodN [ C.expN n k | (n,k) <- P.toPowers m ]
| (_,m) <- P.toMonos p]
where zeroN (C.Norm _ e) = e == 0
norm e = P.variable (C.Norm (e .>= 0) e)
fromGNorm (C.GNorm _ _ n) = P.variable n
maxE [] = 0
maxE es = foldl1 (P.zipCoefficientsWith (max 0) (max 0) max) es
additive fromBExp b gexps =
(+) (sum (fromGNorm `map` gexps))
<$> fromBExp (b .&& bigAnd [ b' | C.GNorm b' _ _ <- gexps ])
multiplicative fromBExp b gexps =
(+) <$> additive fromBExp b gexps
<*> (sum <$> sequence [ (*) (P.variable n) <$> fromBExp (b .&& b')
| C.GNorm b' _ n <- gexps])
ranksFromBExpWith sel bexp = sum <$> mapM sel (S.toList (literals bexp))
optimistic (a :>=: b) = pure $ norm (a - b + 1)
squared m x = sq <$> m x where sq n = n * n
delta n@(C.Norm _ e) = do
ns <- S.toList <$> C.norms <$> evt body (C.nm 1 n)
let f = maxE [ ne | C.Norm _ ne <- ns ] - e
return (C.norm (f .>= 0) f)
shift l@(a :>=: b) =
-- TODO: this is a mess; maybe define wp for non-recursive language?
mkNorm <$> maxE <$> concatMap diffs <$> C.gNorms <$> evt body (C.pred (a .>= b))
where
diffs (C.GNorm gd _ n)
| not (C.isZeroN n) = [ (a' - b') - (a - b) | a' :>=: b' <- S.toList (literals gd) ]
| otherwise = []
-- TODO
mkNorm (E.AddConst _ df) = norm (a - b + 1 + E.constant df)
norm = P.variable
normFromExp e = norm (C.Norm (e .>= 0) e)
normFromBExp b = sum [ normFromExp (e1 - e2 + 1)
| e1 :>=: e2 <- S.toList (literals b) ]
normFromCExp d = sum [ normFromBExp b + norm n
| C.GNorm b k n <- varGNorms d
, k > 0, not (C.isZeroN n) ]
substituteM s = P.fromPolynomialM s (pure . P.coefficient)
stdTemplate = do
let gNorm = normFromCExp (C.guarded c g)
fNorm = normFromCExp (C.guarded (neg c) f)
df <- substituteM delta fNorm
pure (normFromExp 1 + normFromBExp i + normFromBExp c
+ fNorm + gNorm + sum [ normFromBExp c * df ])
refine = template "standard" (const stdTemplate)
<|>> template "shift-avg" shiftAvg
<|>> template "shift-max" shiftMax
<|>> template "cond" conds
<|>> template "square" square
where
template name m k = logMsg name >> m k
infixr 3 <|>>
m1 <|>> m2 = \ r0 -> do { r1 <- m1 r0; check r0 r1 <|> m2 r1 } where
check r r' = if r `normPolyEq` r' then empty else pure r'
shiftAvg = substituteM s where
s n@(C.Norm _ e) = do
evn <- evt body (C.nm 1 n)
pure $ normFromExp $ maxE $ e : [2 * e - fmap floor e' | e' <- lfs evn ]
shiftMax = substituteM s where
s n@(C.Norm _ e) = do
evn <- evt body (C.nm 1 n)
pure $ normFromExp $ maxE $ e : [2 * e - e' | C.GNorm _ _ (C.Norm _ e') <- varGNorms evn ]
conds = substituteM s where
s n = do
evn <- etM Evt body (C.nm 1 n)
pure $ norm n + sum [ normFromBExp b | C.GNorm b _ _ <- varGNorms evn ]
delta n@(C.Norm _ e) = do
ns <- S.toList . C.norms <$> etM Evt body (C.nm 1 n)
let d = maxE [ ne | C.Norm _ ne <- ns ] - e
return (normFromExp d)
square n = pure (n * n)
lfs (C.Sup c1 c2) = lfs c1 ++ lfs c2
lfs (C.Cond _ c1 c2) = lfs c1 ++ lfs c2
lfs (C.Div d _) = lfs d
lfs (C.Plus c1 c2) = [e1 + e2 | e1 <- lfs c1, e2 <- lfs c2]
lfs (C.N k (C.Norm _ e)) = [E.constant k * fmap fromIntegral e]
maxE [] = 0
maxE es = E.norm (foldl1 (P.zipCoefficientsWith (max 0) (max 0) max) es)
-- trivial (b, C.Norm _ e) = b == Bot || S.null (E.variables e)
-- combine <-
-- (logMsg "additive" >> pure additive)
-- <|> (logMsg "multiplicative" >> pure multiplicative)
-- rankFromBExp <-
-- (logMsg "optimistic" >> pure (ranksFromBExpWith optimistic))
-- <|> (logMsg "shift" >> pure (ranksFromBExpWith shift))
-- <|> (logMsg "squared-optimistic" >> pure (squared (ranksFromBExpWith optimistic)))
-- gRanks <- combine rankFromBExp (c .&& i) (varGNorms g)
-- fRanks <- pure (sum (fromGNorm `map` C.gNorms f))
-- fRanks' <- (*) <$> rankFromBExp c
-- <*> P.fromPolynomialM (fmap P.variable . delta) (pure . P.coefficient) fRanks
-- return (toNormList (norm 1 + fRanks + gRanks + fRanks')) --
-- where
-- fromGNorm (C.GNorm _ _ n) = P.variable n
-- maxE [] = 0
-- maxE es = foldl1 (P.zipCoefficientsWith (max 0) (max 0) max) es
-- additive fromBExp b gexps =
-- (+) (sum (fromGNorm `map` gexps))
-- <$> fromBExp (b .&& bigAnd [ b' | C.GNorm b' _ _ <- gexps ])
-- multiplicative fromBExp b gexps =
-- (+) <$> additive fromBExp b gexps
-- <*> (sum <$> sequence [ (*) (P.variable n) <$> fromBExp (b .&& b')
-- | C.GNorm b' _ n <- gexps])
-- squared m x = sq <$> m x where sq n = n * n
-- delta n@(C.Norm _ e) = do
-- ns <- S.toList <$> C.norms <$> evt body (C.nm 1 n)
-- let f = maxE [ ne | C.Norm _ ne <- ns ] - e
-- return (C.norm (f .>= 0) f)
-- shift l@(a :>=: b) =
-- -- TODO: this is a mess; maybe define wp for non-recursive language?
-- mkNorm <$> maxE <$> concatMap diffs <$> C.gNorms <$> evt body (C.predicate (a .>= b))
-- where
-- diffs (C.GNorm gd _ n)
-- | not (C.isZeroN n) = [ (a' - b') - (a - b) | a' :>=: b' <- S.toList (literals gd) ]
-- | otherwise = []
-- -- TODO
-- mkNorm (E.AddConst _ df) = norm (a - b + 1 + E.constant df)
isConstantWrt :: CExp -> C -> Bool
f `isConstantWrt` c = C.variables f `S.disjoint` vs where
......@@ -375,10 +441,8 @@ et Ect (Seq c1 c2) f
| otherwise = C.plus <$> ect c1 C.zero <*> (evt c1 =<< ect c2 f)
et t (Seq c1 c2) f = etM t c1 =<< etM t c2 f
et t (While _ i b c) f =
-- TODO: if x not modified than x should be treated as constant, see absynthRdseql;
logBlkIdM "[While.Step]" $ do
g <- case t of {Evt -> return C.zero; Ect -> ect c C.zero}
-- TODO: simplification when g==0 possible
ns <- extractRanking c i b g f
logMsg2 "Norms" ns
-- TODO: filter those ni's for which hi cannot be computed?
......@@ -390,7 +454,6 @@ et t (While _ i b c) f =
, cont = f
, limt = ns}
logMsg2 "Invariant" ui
--TODO: check; maybe better place?
h <- solveInvariant ui
return (C.cond b h f)
......
......@@ -314,8 +314,7 @@ reduce (GEQ c (toFrac -> (lhs,f1)) (toFrac -> (rhs,f2))) = do
walkLhs _ Bot _ = return []
walkLhs (C.N c' n) ctx k = scale c' (norm n) ctx k
walkLhs (C.Cond g c1 c2) ctx k =
branch g (walkLhs c1) (walkLhs c2) ctx k
walkLhs (C.Cond g c1 c2) ctx k = branch g (walkLhs c1) (walkLhs c2) ctx k
walkLhs (C.Plus c1 c2) ctx k = add (walkLhs c1) (walkLhs c2) ctx k
walkLhs C.Div {} _ _ = error "InvariantSolver.Naive.reduce: div on lhs"
walkLhs C.Sup {} _ _ = error "InvariantSolver.Naive.reduce: sup on lhs"
......
......@@ -7,6 +7,7 @@ import Data.String
import qualified Text.PrettyPrint.ANSI.Leijen as PP
import Data.PWhile.Program (C(..))
import qualified Data.PWhile.Expression as E
import qualified Data.PWhile.CostExpression as C
import PWhile.InferEt
import PWhile.Testbed.DSL
......@@ -19,8 +20,21 @@ prettyResult r = PP.text "***" PP.<$$> PP.pretty r PP.<$$> PP.text "<<<"
run_ :: FilterResult -> (C -> SolveM a) -> Program -> IO [Result a]
run_ t m = run t . m . fst . gen
ec :: FilterResult -> Program -> IO ()
ec t p = run_ t (flip ect C.zero) p >>= mapM_ (putPrettyLn . prettyResult)
ec' :: FilterResult -> Program -> IO ()
ec' t p = run_ t (flip ect C.zero) p >>= mapM_ (putPrettyLn . prettyResult)
ev' :: FilterResult -> C.CExp Rational -> Program -> IO ()
ev' t f p = run_ t (flip evt f) p >>= mapM_ (putPrettyLn . prettyResult)
ec :: Program -> IO ()
ec = ec' Any
ev :: C.CExp Rational -> Program -> IO ()
ev = ev' Any
expr :: E.Exp -> C.CExp Rational
expr e = C.nm 1 (C.norm (e .>= 0) e)
ifThenElse :: BExp -> Program -> Program -> Program
ifThenElse = ite top
......@@ -576,6 +590,13 @@ absynthRdseql = do
-- inferred by Absynth: 2· |[y, m] | + 0.666667· |[x, n] |
-- status: fail; see Prspeed
foo :: Program
foo = do
x <- var "x"; n <- var "n"; m <- var "m"; y <- var "y"
if (y .< m)
then y .~ unif [y, y + 1]
else x .~ unif [x, x + 1, x + 2, x + 3]
absynthRdspeed :: Program
absynthRdspeed = do
x <- var "x"; n <- var "n"; m <- var "m"; y <- var "y"
......@@ -681,18 +702,6 @@ absynthSprdwalk = do
x .~ unif [ x + 0, x + 1 ]
tick
-- polynomial/complex
-- inferred by Absynth: 6· |[0, m] | · |[0, n] | + 3· |[0, n] |+ |[0, y] |
-- status: fail
-- TODO: binomial distribution? norms do not get picked properly
foo = do
m <- var "m"; n <- var "n"; y <- var "y"; x <- var "x"; w <- var "w"
let inv = Top -- x .> 0 .&& m .>= 0 .&& y .>= 0
while inv (x .< n) $ do
x .= x + 1
y .= y + m
absynthComplex :: Program
absynthComplex = do
m <- var "m"; n <- var "n"; y <- var "y"; x <- var "x"; w <- var "w"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment