Commit b2201ded authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft
Browse files

Add reification check

parent f3b04e9c
...@@ -331,14 +331,22 @@ constant one : a ...@@ -331,14 +331,22 @@ constant one : a
clone export AssocAlgebra with type r = r, type a = a, constant one = one clone export AssocAlgebra with type r = r, type a = a, constant one = one
type t = Var int | Add t t | Mul t t | Ext r t type t = Var int | Add t t | Mul t t | Ext r t | Sub t t
type vars = int -> a type vars = int -> a
function asub (x y:a) : a
axiom asub_def: forall x y: a. asub x y = x + (A.(-_) y)
lemma ext_minone:
forall a: a. ($) (R.(-_) R.one) a = A.(-_) a
function interp (x: t) (y: vars) : a = function interp (x: t) (y: vars) : a =
match x with match x with
| Var n -> y n | Var n -> y n
| Add x1 x2 -> interp x1 y + interp x2 y | Add x1 x2 -> interp x1 y + interp x2 y
| Mul x1 x2 -> interp x1 y * interp x2 y | Mul x1 x2 -> interp x1 y * interp x2 y
| Sub x1 x2 -> asub (interp x1 y) (interp x2 y)
| Ext r x -> ($) r (interp x y) | Ext r x -> ($) r (interp x y)
end end
...@@ -374,12 +382,13 @@ let rec lemma mon_append (x1 x2: list int) (y: vars) ...@@ -374,12 +382,13 @@ let rec lemma mon_append (x1 x2: list int) (y: vars)
lemma interp_cons : forall m:m, x:t', y:vars. lemma interp_cons : forall m:m, x:t', y:vars.
interp' (Cons m x) y = interp' x y + interp' (Cons m Nil) y interp' (Cons m x) y = interp' x y + interp' (Cons m Nil) y
let rec lemma interp_sum (x1 x2: t') (y: vars) let rec lemma interp_sum (x1 x2: t')
ensures { interp' (x1++x2) y = interp' x1 y + interp' x2 y } ensures { forall y: vars.
interp' (x1++x2) y = interp' x1 y + interp' x2 y }
variant { x1 } variant { x1 }
= match x1 with = match x1 with
| Nil -> () | Nil -> ()
| Cons _ x -> interp_sum x x2 y end | Cons _ x -> interp_sum x x2 end
let ghost function append_mon (m1 m2:m) let ghost function append_mon (m1 m2:m)
ensures { forall y. interp' (Cons result Nil) y ensures { forall y. interp' (Cons result Nil) y
...@@ -409,6 +418,9 @@ let rec ghost function ext (c:r) (x:t') : t' ...@@ -409,6 +418,9 @@ let rec ghost function ext (c:r) (x:t') : t'
| Nil -> Nil | Nil -> Nil
| Cons (M r m) l -> Cons (M (R.( *) c r) m) (ext c l) end | Cons (M r m) l -> Cons (M (R.( *) c r) m) (ext c l) end
lemma ext_sub:
forall x:t', y:vars. interp' (ext (R.(-_) R.one) x) y = A.(-_) (interp' x y)
let rec ghost function conv (x:t) : t' let rec ghost function conv (x:t) : t'
ensures { forall y. interp x y = interp' result y } ensures { forall y. interp x y = interp' result y }
= match x with = match x with
...@@ -416,6 +428,7 @@ let rec ghost function conv (x:t) : t' ...@@ -416,6 +428,7 @@ let rec ghost function conv (x:t) : t'
| Add x1 x2 -> (conv x1) ++ (conv x2) | Add x1 x2 -> (conv x1) ++ (conv x2)
| Mul x1 x2 -> mul_devel (conv x1) (conv x2) | Mul x1 x2 -> mul_devel (conv x1) (conv x2)
| Ext r x -> ext r (conv x) | Ext r x -> ext r (conv x)
| Sub x1 x2 -> (conv x1) ++ (ext (R.(-_) R.one) (conv x2))
end end
...@@ -792,10 +805,12 @@ module InfIntMatrix ...@@ -792,10 +805,12 @@ module InfIntMatrix
lemma add_size: lemma add_size:
forall a b: mat, r c: int. size a r c -> size b r c -> size (add a b) r c forall a b: mat, r c: int. size a r c -> size b r c -> size (add a b) r c
by (forall i j:int. (in_bounds a i j \/ in_bounds b i j) <-> in_bounds (add a b) i j) by (forall i j:int.
(in_bounds a i j \/ in_bounds b i j) <-> in_bounds (add a b) i j)
lemma add_assoc: lemma add_assoc:
forall a b c: mat. add (add a b) c = add a (add b c) by add (add a b) c == add a (add b c) forall a b c: mat. add (add a b) c = add a (add b c)
by add (add a b) c == add a (add b c)
lemma add_commutative: lemma add_commutative:
forall a b: mat. add a b = add b a by add a b == add b a forall a b: mat. add a b = add b a by add a b == add b a
...@@ -812,6 +827,11 @@ module InfIntMatrix ...@@ -812,6 +827,11 @@ module InfIntMatrix
function sub (a b: mat) : mat = add a (opp b) function sub (a b: mat) : mat = add a (opp b)
lemma sub_size:
forall a b: mat, r c: int. size a r c -> size b r c -> size (sub a b) r c
by (forall i j:int.
(in_bounds a i j \/ in_bounds b i j) <-> in_bounds (sub a b) i j)
lemma opp_involutive: lemma opp_involutive:
forall m. opp (opp m) = m by opp (opp m) == m forall m. opp (opp m) = m by opp (opp m) == m
...@@ -954,8 +974,31 @@ lemma assoc_mul_ext: ...@@ -954,8 +974,31 @@ lemma assoc_mul_ext:
lemma unit_ext: lemma unit_ext:
forall x: mat. extp 1 x = x by extp 1 x == x forall x: mat. extp 1 x = x by extp 1 x == x
let lemma comm_mul_ext_ij (x y: mat) (r i j: int)
requires { 0 <= i /\ 0 <= j }
ensures { get (mul (extp r x) y) i j = r * (get (mul x y) i j) }
ensures { get (mul x (extp r y)) i j = r * (get (mul x y) i j) }
=
let b = mul_cell_bound x y i j in
assert { mul_cell_bound (extp r x) y i j = b
= mul_cell_bound x (extp r y) i j };
sum_ext (mul_atom (extp r x) y i j) (smulf (mul_atom x y i j) r) 0 b;
sum_ext (mul_atom x (extp r y) i j) (smulf (mul_atom x y i j) r) 0 b;
sum_mult (mul_atom (extp r x) y i j) 0 b r;
sum_mult (mul_atom x (extp r y) i j) 0 b r;
assert { get (mul (extp r x) y) i j
= r * (get (mul x y) i j)
= get (mul x (extp r y)) i j
by get (mul (extp r x) y) i j
= mul_cell (extp r x) y i j
= r * mul_cell x y i j
= mul_cell x (extp r y) i j
= get (mul x (extp r y)) i j
so r * mul_cell x y i j = r * (get (mul x y) i j) }
lemma comm_mul_ext: lemma comm_mul_ext:
forall x y: mat, r: int. extp r (mul x y) = mul (extp r x) y = mul x (extp r y) forall x y: mat, r: int.
extp r (mul x y) = mul (extp r x) y = mul x (extp r y)
by extp r (mul x y) == mul (extp r x) y == mul x (extp r y) by extp r (mul x y) == mul (extp r x) y == mul x (extp r y)
end end
...@@ -967,7 +1010,7 @@ use import int.Int ...@@ -967,7 +1010,7 @@ use import int.Int
let predicate eq0_int (x:int) = x = 0 let predicate eq0_int (x:int) = x = 0
clone export AssocAlgebraDecision with type r = int, type a = mat, constant R.zero = Int.zero, constant R.one = Int.one, function R.(+) = (+), function R.(-_) = (-_), function R.(*) = (*),constant A.zero = mzero, constant one = id, function (+) = add, function A.(-_) = opp, function ( *) = mul, function ($) = extp, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc, goal A.Mul_distr_l, goal A.Mul_distr_r clone export AssocAlgebraDecision with type r = int, type a = mat, constant R.zero = Int.zero, constant R.one = Int.one, function R.(+) = (+), function R.(-_) = (-_), function R.(*) = (*),constant A.zero = mzero, constant one = id, function (+) = add, function A.(-_) = opp, function ( *) = mul, function asub = sub, function ($) = extp, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc, goal A.Mul_distr_l, goal A.Mul_distr_r, goal asub_def
end end
...@@ -1150,15 +1193,15 @@ module MatrixTests ...@@ -1150,15 +1193,15 @@ module MatrixTests
assert { c = mul a b }; assert { c = mul a b };
let m1 = strassen_pow2 (add a11 a22) (add b11 b22) (k-1) in let m1 = strassen_pow2 (add a11 a22) (add b11 b22) (k-1) in
let m2 = strassen_pow2 (add a21 a22) b11 (k-1) in let m2 = strassen_pow2 (add a21 a22) b11 (k-1) in
let m3 = strassen_pow2 a11 (add b12 (extp (-1) b22)) (k-1) in let m3 = strassen_pow2 a11 (sub b12 b22) (k-1) in
let m4 = strassen_pow2 a22 (add b21 (extp (-1) b11)) (k-1) in let m4 = strassen_pow2 a22 (sub b21 b11) (k-1) in
let m5 = strassen_pow2 (add a11 a12) b22 (k-1) in let m5 = strassen_pow2 (add a11 a12) b22 (k-1) in
let m6 = strassen_pow2 (add a21 (extp (-1) a11)) (add b11 b12) (k-1) in let m6 = strassen_pow2 (sub a21 a11) (add b11 b12) (k-1) in
let m7 = strassen_pow2 (add a12 (extp (-1) a22)) (add b21 b22) (k-1) in let m7 = strassen_pow2 (sub a12 a22) (add b21 b22) (k-1) in
let s11 = add m1 (add m4 (add m7 (extp (-1) m5))) in let s11 = add m1 (add m4 (sub m7 m5)) in
let s12 = add m3 m5 in let s12 = add m3 m5 in
let s21 = add m2 m4 in let s21 = add m2 m4 in
let s22 = add m1 (add m3 (add m6 (extp (-1) m2))) in let s22 = add m1 (add m3 (sub m6 m2)) in
assert { s11 = c11 }; assert { s11 = c11 };
assert { s12 = c12 }; assert { s12 = c12 };
assert { s21 = c21 }; assert { s21 = c21 };
......
...@@ -27,6 +27,7 @@ let debug = true ...@@ -27,6 +27,7 @@ let debug = true
let expl_reified_goal = Ident.create_label "expl:reified goal" let expl_reified_goal = Ident.create_label "expl:reified goal"
let expl_reification_check = Ident.create_label "expl:reification check"
let expl_normalized_goal = Ident.create_label "expl:normalized goal" let expl_normalized_goal = Ident.create_label "expl:normalized goal"
let expl_normalization_check = Ident.create_label "expl:normalization check" let expl_normalization_check = Ident.create_label "expl:normalization check"
...@@ -157,14 +158,9 @@ let reify_goal interp task = ...@@ -157,14 +158,9 @@ let reify_goal interp task =
aux bl aux bl
| _ -> raise Exit | _ -> raise Exit
in in
let rec reify_term (env, fr) (t:term) = let reify_term (env, fr) (t:term) =
if debug then Format.printf "reify_term %a@." Pretty.print_term t; if debug then Format.printf "reify_term %a@." Pretty.print_term t;
match t.t_node with match t.t_node with
| Tapp(ls, [t1; t2]) when ls_equal ls ps_equ ->
if debug then Format.printf "case =@.";
let (env, fr, t1) = reify_term (env, fr) t1 in
let (env, fr, t2) = reify_term (env, fr) t2 in
env, fr, t_equ t1 t2
| Tquant (Tforall, _) -> | Tquant (Tforall, _) ->
raise Exit (* we introduce premises before the transformation *) raise Exit (* we introduce premises before the transformation *)
| _ when oty_equal t.t_ty interp.ls_value -> | _ when oty_equal t.t_ty interp.ls_value ->
...@@ -178,7 +174,6 @@ let reify_goal interp task = ...@@ -178,7 +174,6 @@ let reify_goal interp task =
Pretty.print_ty (Opt.get interp.ls_value); Pretty.print_ty (Opt.get interp.ls_value);
raise Exit raise Exit
in in
let open Task in let open Task in
match task with match task with
| Some | Some
...@@ -188,7 +183,12 @@ let reify_goal interp task = ...@@ -188,7 +183,12 @@ let reify_goal interp task =
} -> } ->
begin try begin try
if debug then Format.printf "start@."; if debug then Format.printf "start@.";
let (env, _fr, t) = reify_term (Mterm.empty, 0) f in begin match f.t_node with
| Tapp(ls, [f1; f2]) when ls_equal ls ps_equ ->
if debug then Format.printf "case =@.";
let (env, fr, t1) = reify_term (Mterm.empty, 0) f1 in
let (env, _fr, t2) = reify_term (env, fr) f2 in
let t = t_equ t1 t2 in
if debug then Format.printf "building y map@."; if debug then Format.printf "building y map@.";
let d = create_param_decl ly in let d = create_param_decl ly in
let prev = Task.add_decl prev d in let prev = Task.add_decl prev d in
...@@ -208,8 +208,25 @@ let reify_goal interp task = ...@@ -208,8 +208,25 @@ let reify_goal interp task =
(id_fresh "reified_goal" (id_fresh "reified_goal"
~label:(Slab.singleton expl_reified_goal)) in ~label:(Slab.singleton expl_reified_goal)) in
let d = Decl.create_prop_decl Pgoal pr t in let d = Decl.create_prop_decl Pgoal pr t in
Task.add_decl prev d let task_r = Task.add_decl prev d in
with Exit -> task end let tc1 = t_app ps_equ [t1; f1] f.t_ty in
let tc2 = t_app ps_equ [t2; f2] f.t_ty in
let prc1 = Decl.create_prsymbol
(id_fresh "reify_check"
~label:(Slab.singleton
expl_reification_check)) in
let prc2 = Decl.create_prsymbol
(id_fresh "reify_check"
~label:(Slab.singleton
expl_reification_check)) in
let d1 = Decl.create_prop_decl Pgoal prc1 tc1 in
let d2 = Decl.create_prop_decl Pgoal prc2 tc2 in
let task_c1 = Task.add_decl prev d1 in
let task_c2 = Task.add_decl prev d2 in
[task_r; task_c1; task_c2]
| _ -> raise Exit
end
with Exit -> [task] end
| _ -> assert false | _ -> assert false
...@@ -268,7 +285,7 @@ let normalize_goal_t (interp, norm) = Trans.store (normalize_goal (interp, norm) ...@@ -268,7 +285,7 @@ let normalize_goal_t (interp, norm) = Trans.store (normalize_goal (interp, norm)
let normalize_in_goal = Trans.bind collect_interp_normalize normalize_goal_t let normalize_in_goal = Trans.bind collect_interp_normalize normalize_goal_t
let () = Trans.register_transform let () = Trans.register_transform_l
"reify_in_goal" "reify_in_goal"
~desc:"Reify@ goal@ to@ declared@ target@ datatype." ~desc:"Reify@ goal@ to@ declared@ target@ datatype."
reify_in_goal reify_in_goal
......
val meta_reify_target : Theory.meta val meta_reify_target : Theory.meta
val meta_normalize_function : Theory.meta val meta_normalize_function : Theory.meta
val reify_in_goal : Task.task Trans.trans val reify_in_goal : Task.task list Trans.trans
val normalize_in_goal : Task.task list Trans.trans val normalize_in_goal : Task.task list Trans.trans
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment