Commit ce1b2bc5 by AVANZINI Martin

### all examples pass :)

parent a6b6c13c
 ... ... @@ -138,7 +138,8 @@ simplify (Cond b c d) = guarded :: (Eq c, Num c) => BExp -> CExp c -> CExp c guarded g c = cond g c zero guarded g c | c == zero = zero | otherwise = cond g c zero -- oneV :: Num c => CExp c -- oneV = N 1 (Norm Top "@@ONE@@") ... ... @@ -234,9 +235,9 @@ instance PP.Pretty Norm where instance (PP.Pretty c, Eq c, Num c) => PP.Pretty (CExp c) where pretty = pp id where -- pp _ (N 0 _) = PP.text "0" -- pp par (N 1 n) = par (ppNorm n) -- pp par (N k (Norm _ 1)) = PP.pretty k pp _ (N 0 _) = PP.text "0" pp par (N 1 n) = par (ppNorm n) pp _ (N k (Norm _ 1)) = PP.pretty k pp _ (N k n) = PP.pretty k PP.<> PP.text "·" PP.<> PP.parens (ppNorm n) pp par (Div c e) = par (infx (pp id c) "/" (PP.pretty e)) pp par (Plus c d) = par (infx (pp id c) "+" (pp id d)) ... ...
 ... ... @@ -174,9 +174,12 @@ boundedSum f (l,o) = logBlkIdM msg \$ bs fi where fi = f vi msg = "[sum] " ++ renderPretty (PP.text "sum" PP.<> PP.tupled [ PP.pretty l, PP.pretty o, PP.pretty fi]) bs g | g == C.zero = return C.zero bs (C.Plus g1 g2) = C.plus <\$> bs g1 <*> bs g2 bs g = do bs g | g == C.zero = return C.zero bs (C.Plus g1 g2) = C.plus <\$> bs g1 <*> bs g2 bs (C.Cond g g1 g2) | g2 /= C.zero = C.plus <\$> bs (C.guarded g g1) <*> bs (C.guarded (neg g) g2) bs g = do logMsg2 "norms" ns (ui,s,dir) <- alternative [(uiDown, o, "Down"), (uiUp, l, "Up")] logMsg2 ("SumFn["++dir++"]") ui ... ... @@ -198,7 +201,10 @@ boundedSum f (l,o) = logBlkIdM msg \$ bs fi where , step = (E.substitute i (vi + 1) . C.fromNorm) `map` ns , cont = C.zero , limt = ns } ns = [n | C.GNorm _ _ n <- varGNorms g] ns = [ C.norm (e .>= 0) e `C.mulN` n | C.GNorm b _ n <- varGNorms g , e1 :>=: e2 <- S.toList (literals b) , let e = e1 - e2 + 1] expectation :: E.Dist E.Exp -> (E.Exp -> CExp) -> SolveM CExp ... ... @@ -255,56 +261,82 @@ expectation (E.Rand n) f = -- Expectation Transformer ---------------------------------------------------------------------- type NormPoly = P.Polynomial C.Norm Int type NormTemplate = P.Polynomial C.Norm Int normTemplateEq :: NormTemplate -> NormTemplate -> Bool p1 `normTemplateEq` p2 = and \$ P.coefficients \$ P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2 norm :: C.Norm -> NormTemplate norm = P.variable normFromExp :: E.Exp -> NormTemplate normFromExp e = norm (C.Norm (e .>= 0) e) normFromBExp :: BExp -> NormTemplate normFromBExp b = sum [ normFromExp (e1 - e2 + 1) | e1 :>=: e2 <- S.toList (literals b) ] normFromCExp :: C.CExp Rational -> NormTemplate normFromCExp d = sum [ normFromBExp b + norm n | C.GNorm b k n <- varGNorms d , k > 0, not (C.isZeroN n) ] substituteM :: (C.Norm -> SolveM NormTemplate) -> NormTemplate -> SolveM NormTemplate substituteM s = P.fromPolynomialM s (pure . P.coefficient) template :: String -> SolveM NormTemplate -> SolveM NormTemplate template name = logBlkIdM (name++"-template") normPolyEq :: NormPoly -> NormPoly -> Bool p1 `normPolyEq` p2 = and \$ P.coefficients \$ P.zipCoefficientsWith (const False) (const False) (\ _ _ -> True) p1 p2 extractRanking :: C -> BExp -> BExp -> CExp -> CExp -> SolveM [C.Norm] extractRanking body i c g f = toNormList <\$> refine 1 where extractRanking body i c g f = fmap toNormList \$ do lin <- linearTemplate template "linear" (pure lin) <|> template "shift-avg" (refine shiftAvg lin) <|> template "shift-avg" (refine shiftAvg lin) <|> template "conditions" (refine conds lin) <|> template "shift-max" (refine shiftMax lin) <|> do let sq = lin * lin template "square" (pure sq) where toNormList p = filter (not . zeroN) [ C.prodN [ C.expN n k | (n,k) <- P.toPowers m ] | (_,m) <- P.toMonos p] where zeroN (C.Norm _ e) = e == 0 norm = P.variable normFromExp e = norm (C.Norm (e .>= 0) e) normFromBExp b = sum [ normFromExp (e1 - e2 + 1) | e1 :>=: e2 <- S.toList (literals b) ] normFromCExp d = sum [ normFromBExp b + norm n | C.GNorm b k n <- varGNorms d , k > 0, not (C.isZeroN n) ] substituteM s = P.fromPolynomialM s (pure . P.coefficient) refine m r = do r' <- m r if r' `normTemplateEq` r then empty else pure r' stdTemplate = do linearTemplate = do let gNorm = normFromCExp (C.guarded c g) fNorm = normFromCExp (C.guarded (neg c) f) df <- substituteM delta fNorm pure (normFromExp 1 + normFromBExp i + normFromBExp c + fNorm + gNorm + sum [ normFromBExp c * df ]) refine = template "standard" (const stdTemplate) <|>> template "shift-avg" shiftAvg <|>> template "shift-max" shiftMax <|>> template "cond" conds <|>> template "square" square where template name m k = logMsg name >> m k infixr 3 <|>> m1 <|>> m2 = \ r0 -> do { r1 <- m1 r0; check r0 r1 <|> m2 r1 } where check r r' = if r `normPolyEq` r' then empty else pure r' -- refine = template "standard" (const stdTemplate) -- <|>> (\ l -> template "shift-avg" shiftAvg -- <|>> template "shift-max" shiftMax -- <|>> template "cond" conds -- <|>> template "square" square) -- infixr 3 <|>> -- m1 <|>> m2 = \ r0 -> do { r1 <- m1 r0; check r0 r1 <|> m2 r1 } where -- check r r' = if r `normPolyEq` r' then empty else pure r' shiftAvg = substituteM s where s n@(C.Norm _ e) = do evn <- evt body (C.nm 1 n) evn <- etM Evt body (C.nm 1 n) pure \$ normFromExp \$ maxE \$ e : [2 * e - fmap floor e' | e' <- lfs evn ] shiftMax = substituteM s where s n@(C.Norm _ e) = do evn <- evt body (C.nm 1 n) evn <- etM Evt body (C.nm 1 n) pure \$ normFromExp \$ maxE \$ e : [2 * e - e' | C.GNorm _ _ (C.Norm _ e') <- varGNorms evn ] conds = substituteM s where ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!