Commit ce1b2bc5 authored by AVANZINI Martin's avatar AVANZINI Martin
Browse files

all examples pass :)

parent a6b6c13c
......@@ -138,7 +138,8 @@ simplify (Cond b c d) =
guarded :: (Eq c, Num c) => BExp -> CExp c -> CExp c
guarded g c = cond g c zero
guarded g c | c == zero = zero
| otherwise = cond g c zero
-- oneV :: Num c => CExp c
-- oneV = N 1 (Norm Top "@@ONE@@")
......@@ -234,9 +235,9 @@ instance PP.Pretty Norm where
instance (PP.Pretty c, Eq c, Num c) => PP.Pretty (CExp c) where
pretty = pp id where
-- pp _ (N 0 _) = PP.text "0"
-- pp par (N 1 n) = par (ppNorm n)
-- pp par (N k (Norm _ 1)) = PP.pretty k
pp _ (N 0 _) = PP.text "0"
pp par (N 1 n) = par (ppNorm n)
pp _ (N k (Norm _ 1)) = PP.pretty k
pp _ (N k n) = PP.pretty k PP.<> PP.text "·" PP.<> PP.parens (ppNorm n)
pp par (Div c e) = par (infx (pp id c) "/" (PP.pretty e))
pp par (Plus c d) = par (infx (pp id c) "+" (pp id d))
......
......@@ -174,9 +174,12 @@ boundedSum f (l,o) = logBlkIdM msg $ bs fi where
fi = f vi
msg = "[sum] " ++ renderPretty (PP.text "sum" PP.<> PP.tupled [ PP.pretty l, PP.pretty o, PP.pretty fi])
bs g | g == C.zero = return C.zero
bs (C.Plus g1 g2) = C.plus <$> bs g1 <*> bs g2
bs g = do
bs g
| g == C.zero = return C.zero
bs (C.Plus g1 g2) = C.plus <$> bs g1 <*> bs g2
bs (C.Cond g g1 g2)
| g2 /= C.zero = C.plus <$> bs (C.guarded g g1) <*> bs (C.guarded (neg g) g2)
bs g = do
logMsg2 "norms" ns
(ui,s,dir) <- alternative [(uiDown, o, "Down"), (uiUp, l, "Up")]
logMsg2 ("SumFn["++dir++"]") ui
......@@ -198,7 +201,10 @@ boundedSum f (l,o) = logBlkIdM msg $ bs fi where
, step = (E.substitute i (vi + 1) . C.fromNorm) `map` ns
, cont = C.zero
, limt = ns }
ns = [n | C.GNorm _ _ n <- varGNorms g]
ns = [ C.norm (e .>= 0) e `C.mulN` n
| C.GNorm b _ n <- varGNorms g
, e1 :>=: e2 <- S.toList (literals b)
, let e = e1 - e2 + 1]
expectation :: E.Dist E.Exp -> (E.Exp -> CExp) -> SolveM CExp
......@@ -255,56 +261,82 @@ expectation (E.Rand n) f =
-- Expectation Transformer
----------------------------------------------------------------------
type NormPoly = P.Polynomial C.Norm Int
type NormTemplate = P.Polynomial C.Norm Int
normTemplateEq :: NormTemplate -> NormTemplate -> Bool
p1 `normTemplateEq` p2 = and $ P.coefficients $
P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2
norm :: C.Norm -> NormTemplate
norm = P.variable
normFromExp :: E.Exp -> NormTemplate
normFromExp e = norm (C.Norm (e .>= 0) e)
normFromBExp :: BExp -> NormTemplate
normFromBExp b = sum [ normFromExp (e1 - e2 + 1)
| e1 :>=: e2 <- S.toList (literals b) ]
normFromCExp :: C.CExp Rational -> NormTemplate
normFromCExp d = sum [ normFromBExp b + norm n
| C.GNorm b k n <- varGNorms d
, k > 0, not (C.isZeroN n) ]
substituteM :: (C.Norm -> SolveM NormTemplate) -> NormTemplate -> SolveM NormTemplate
substituteM s = P.fromPolynomialM s (pure . P.coefficient)
template :: String -> SolveM NormTemplate -> SolveM NormTemplate
template name = logBlkIdM (name++"-template")
normPolyEq :: NormPoly -> NormPoly -> Bool
p1 `normPolyEq` p2 = and $ P.coefficients $ P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2
extractRanking :: C -> BExp -> BExp -> CExp -> CExp -> SolveM [C.Norm]
extractRanking body i c g f = toNormList <$> refine 1
where
extractRanking body i c g f = fmap toNormList $ do
lin <- linearTemplate
template "linear" (pure lin)
<|> template "shift-avg" (refine shiftAvg lin)
<|> template "shift-avg" (refine shiftAvg lin)
<|> template "conditions" (refine conds lin)
<|> template "shift-max" (refine shiftMax lin)
<|> do
let sq = lin * lin
template "square" (pure sq)
where
toNormList p = filter (not . zeroN) [ C.prodN [ C.expN n k | (n,k) <- P.toPowers m ]
| (_,m) <- P.toMonos p]
where zeroN (C.Norm _ e) = e == 0
norm = P.variable
normFromExp e = norm (C.Norm (e .>= 0) e)
normFromBExp b = sum [ normFromExp (e1 - e2 + 1)
| e1 :>=: e2 <- S.toList (literals b) ]
normFromCExp d = sum [ normFromBExp b + norm n
| C.GNorm b k n <- varGNorms d
, k > 0, not (C.isZeroN n) ]
substituteM s = P.fromPolynomialM s (pure . P.coefficient)
refine m r = do
r' <- m r
if r' `normTemplateEq` r then empty else pure r'
stdTemplate = do
linearTemplate = do
let gNorm = normFromCExp (C.guarded c g)
fNorm = normFromCExp (C.guarded (neg c) f)
df <- substituteM delta fNorm
pure (normFromExp 1 + normFromBExp i + normFromBExp c
+ fNorm + gNorm + sum [ normFromBExp c * df ])
refine = template "standard" (const stdTemplate)
<|>> template "shift-avg" shiftAvg
<|>> template "shift-max" shiftMax
<|>> template "cond" conds
<|>> template "square" square
where
template name m k = logMsg name >> m k
infixr 3 <|>>
m1 <|>> m2 = \ r0 -> do { r1 <- m1 r0; check r0 r1 <|> m2 r1 } where
check r r' = if r `normPolyEq` r' then empty else pure r'
-- refine = template "standard" (const stdTemplate)
-- <|>> (\ l -> template "shift-avg" shiftAvg
-- <|>> template "shift-max" shiftMax
-- <|>> template "cond" conds
-- <|>> template "square" square)
-- infixr 3 <|>>
-- m1 <|>> m2 = \ r0 -> do { r1 <- m1 r0; check r0 r1 <|> m2 r1 } where
-- check r r' = if r `normPolyEq` r' then empty else pure r'
shiftAvg = substituteM s where
s n@(C.Norm _ e) = do
evn <- evt body (C.nm 1 n)
evn <- etM Evt body (C.nm 1 n)
pure $ normFromExp $ maxE $ e : [2 * e - fmap floor e' | e' <- lfs evn ]
shiftMax = substituteM s where
s n@(C.Norm _ e) = do
evn <- evt body (C.nm 1 n)
evn <- etM Evt body (C.nm 1 n)
pure $ normFromExp $ maxE $ e : [2 * e - e' | C.GNorm _ _ (C.Norm _ e') <- varGNorms evn ]
conds = substituteM s where
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment