Commit 55dafb2b authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft
Browse files

Refactor reflection into a single transformation with a prop parameter

parent b86c405a
......@@ -195,9 +195,9 @@ LIB_TRANSFORM = simplify_formula inlining split_goal induction \
eliminate_epsilon intro_projections_counterexmp \
intro_vc_vars_counterexmp prepare_for_counterexmp \
instantiate_predicate smoke_detector \
induction_pr prop_curry eliminate_literal reification \
induction_pr prop_curry eliminate_literal \
args_wrapper generic_arg_trans_utils case apply \
ind_itp destruct cut
ind_itp destruct cut reification
LIB_PRINTER = cntexmp_printer alt_ergo why3printer smtv1 smtv2 coq\
pvs isabelle \
......
......@@ -68,17 +68,18 @@ module Refptr
alias { result.data with x.pcontents.data }
(*ensures { result = !!x } *) (* let ... = !!x => illegal alias when used*)
(*
let (::=) (r: refp) (v: ptr int) : unit
requires { valid v 0 }
writes { r }
ensures { !!r = v }
= let _ = if true then r.pcontents.data else v.data in
r.pcontents <- incr v (Int32.of_int 0)
*)
end
(*
module A
use import map.Map
......@@ -116,4 +117,4 @@ module A
let b = get !!s in
assert { b = 33 };
end
*)
\ No newline at end of file
......@@ -499,9 +499,7 @@ let lemma norm' (x1 x2: t')
ensures { eq' x1 x2 }
= ()
meta reify_target function interp
meta reify_normalize function normalize
meta rewrite_def function interp
end
......@@ -540,6 +538,11 @@ axiom row_zeros_def:
axiom col_zeros_def:
forall m: mat F.t, i j: int. 0 <= j -> i >= col_zeros m j -> get m i j = F.zero
axiom row_zeros_nonneg:
forall m: mat F.t, i: int. 0 <= i -> 0 <= row_zeros m i
axiom col_zeros_nonneg:
forall m: mat F.t, j: int. 0 <= j -> 0 <= col_zeros m j
(*FIXME should be invariants*)
axiom set_def_changed:
......@@ -590,6 +593,12 @@ predicate (===) (m1 m2: mat F.t) =
predicate in_bounds (m: mat F.t) (i j: int) =
0 <= i < col_zeros m j /\ 0 <= j < row_zeros m i
let lemma ext_by_bounds (m1 m2: mat F.t)
requires { m1 === m2 }
requires { forall i j. in_bounds m1 i j -> get m1 i j = get m2 i j }
ensures { m1 == m2 }
= ()
lemma oob_zero:
forall m: mat F.t, i j: int. 0 <= i -> 0 <= j -> not in_bounds m i j
-> get m i j = F.zero
......@@ -637,11 +646,11 @@ module InfMatrix
axiom create_rowz:
forall rz cz: int -> int, f: int -> int -> t, i: int.
0 <= i -> row_zeros (create rz cz f) i = rz i
0 <= i -> 0 <= rz i -> row_zeros (create rz cz f) i = rz i
axiom create_colz:
forall rz cz: int -> int, f: int -> int -> t, j: int.
0 <= j -> col_zeros (create rz cz f) j = cz j
0 <= j -> 0 <= cz j -> col_zeros (create rz cz f) j = cz j
axiom create_get_ib:
forall rz cz: int -> int, f: int -> int -> t, i j: int.
......@@ -652,12 +661,14 @@ module InfMatrix
0 <= i -> 0 <= j -> (i >= cz j \/ j >= rz i) ->
get (create rz cz f) i j = tzero
function set (m: mat) (i j:int) (v:t) : mat =
if 0 <= i /\ 0 <= j
then
create
(fun i1 -> if i1 = i then max (j+1) (row_zeros m i) else row_zeros m i1)
(fun j1 -> if j1 = j then max (i+1) (col_zeros m j) else col_zeros m j1)
(fun i1 j1 -> if i1 = i && j1 = j then v else get m i1 j1)
else m
clone export InfMatrixGen with type mat 'a = mat,
type F.t = t,
......@@ -673,16 +684,16 @@ module InfMatrix
lemma set_def_rowz_unchanged,
lemma set_def_other_rowz,
lemma set_def_other_colz
(*
lemma create_ib:
forall rz cz: int -> int, f: int -> int -> t, i j: int.
in_bounds (create rz cz f) i j -> get (create rz cz f) i j = f i j
*)
function fcreate (r c: int) (f: int -> int -> t) : mat =
create (fun _ -> c) (fun _ -> r) f
(*create (fun i -> if 0 <= i < r then c else 0)
(fun j -> if 0 <= j < c then r else 0)
f*)
create (fun _ -> max 0 c) (fun _ -> max 0 r) f
lemma fcreate_get_ib:
forall r c i j: int, f: int -> int -> t.
......@@ -693,7 +704,51 @@ module InfMatrix
0 <= i -> 0 <= j -> (i >= r \/ j >= c) -> get (fcreate r c f) i j = tzero
lemma fcreate_size:
forall r c: int, f: int -> int -> t. size (fcreate r c f) r c
forall r c: int, f: int -> int -> t. 0 <= r -> 0 <= c ->
size (fcreate r c f) r c
end
module Sum_extended
use import int.Int
use import int.Sum
function addf (f g:int -> int) : int -> int = fun x -> f x + g x
function smulf (f:int -> int) (l:int) : int -> int = fun x -> l * f x
let rec lemma sum_mult (f:int -> int) (a b l:int) : unit
ensures { sum (smulf f l) a b = l * sum f a b }
variant { b - a }
= if b > a then sum_mult f a (b-1) l
let rec lemma sum_add (f g:int -> int) (a b:int) : unit
ensures { sum (addf f g) a b = sum f a b + sum g a b }
variant { b - a }
= if b > a then sum_add f g a (b-1)
function sumf (f:int -> int -> int) (a b:int) : int -> int = fun x -> sum (f x) a b
let rec lemma fubini (f1 f2: int -> int -> int) (a b c d: int) : unit
requires { forall x y. a <= x < b /\ c <= y < d -> f1 x y = f2 y x }
ensures { sum (sumf f1 c d) a b = sum (sumf f2 a b) c d }
variant { b - a }
= if b <= a
then assert { forall x. sumf f2 a b x = 0 }
else begin
fubini f1 f2 a (b-1) c d;
assert { let ha = addf (sumf f2 a (b-1)) (f1 (b-1)) in
sum (sumf f2 a b) c d = sum ha c d
by forall y. c <= y < d -> sumf f2 a b y = ha y }
end
let ghost sum_ext (f g: int -> int) (a b: int) : unit
requires {forall i. a <= i < b -> f i = g i }
ensures { sum f a b = sum g a b }
= ()
end
......@@ -912,9 +967,11 @@ module InfIntMatrix
= 0 + (get id i i)*(get m i j) + 0
= get m i j)
/\ (forall i j. 0 <= i -> 0 <= j -> not (in_bounds m i j)
-> mul_cell id m i j = 0 = get m i j))
-> mul_cell id m i j
= sum (fun k -> mul_atom id m i j k) 0 (mul_cell_bound id m i j)
= 0 = get m i j))
use import sum_extended.Sum_extended (* examples/verify_this_2016/matrix_multiplication *)
use import Sum_extended
function ft1 (a b c: mat) (i j: int) : int -> int -> int =
fun k -> smulf (mul_atom a b i k) (get c k j)
......@@ -923,25 +980,111 @@ module InfIntMatrix
fun k -> smulf (mul_atom b c k j) (get a i k)
let lemma mul_assoc_get (a b c: mat) (i j: int)
(* requires { 0 <= i < row_zeros a j /\ 0 <= j < col_zeros c i } *)
requires { 0 <= i /\ 0 <= j }
ensures { get (mul (mul a b) c) i j = get (mul a (mul b c)) i j }
requires { 0 <= i /\ 0 <= j }
ensures { get (mul (mul a b) c) i j = get (mul a (mul b c)) i j }
= let ft1 = ft1 a b c i j in
let ft2 = ft2 a b c i j in
fubini ft1 ft2 0 (col_zeros b i) 0 (col_zeros a i);
sum_ext (mul_atom (mul a b) c i j) (sumf ft1 0 (col_zeros a i)) 0 (col_zeros b i);
assert { get (mul (mul a b) c) i j = sum (sumf ft1 0 (col_zeros a i)) 0 (col_zeros b i) };
sum_ext (mul_atom a (mul b c) i j) (sumf ft2 0 (col_zeros b i)) 0 (col_zeros a i);
assert { get (mul a (mul b c)) i j = sum (sumf ft2 0 (col_zeros b i)) 0 (col_zeros a i) }
lemma mul_assoc:
forall a b c: mat.
let ab = mul a b in
let bc = mul b c in
let a_bc = mul a bc in
let ab_c = mul ab c in
a_bc = ab_c
by a_bc == ab_c
let ab = mul a b in
let bc = mul b c in
let m_ab_c = mul_cell_bound ab c i j in
let m_a_bc = mul_cell_bound a bc i j in
fubini ft1 ft2 0 m_ab_c 0 m_a_bc;
assert { forall k. 0 <= k < m_ab_c -> mul_cell_bound a b i k <= m_a_bc
by mul_cell_bound a b i k <= row_zeros a i
so mul_cell_bound a b i k <= col_zeros b k
so col_zeros bc j = maxf (fun k -> col_zeros b k) 0 (col_zeros c j)
so 0 <= k < col_zeros c j
so col_zeros b k <= maxf (fun k -> col_zeros b k) 0 (col_zeros c j)
so col_zeros b k <= col_zeros bc j };
assert { forall k. 0 <= k < m_ab_c ->
sumf ft1 0 m_a_bc k = sumf ft1 0 (mul_cell_bound a b i k) k
by sumf ft1 0 m_a_bc k
= sum (ft1 k) 0 m_a_bc
= sum (ft1 k) 0 (mul_cell_bound a b i k)
+ sum (ft1 k) (mul_cell_bound a b i k) m_a_bc
= sumf ft1 0 (mul_cell_bound a b i k) k
+ sum (ft1 k) (mul_cell_bound a b i k) m_a_bc
so forall l. l >= mul_cell_bound a b i k -> ft1 k l = 0
so sum (ft1 k) (mul_cell_bound a b i k) m_a_bc = 0 };
assert { forall k. 0 <= k < m_ab_c ->
mul_atom ab c i j k = sumf ft1 0 m_a_bc k
by get ab i k = mul_cell a b i k
so sumf ft1 0 m_a_bc k
= sumf ft1 0 (mul_cell_bound a b i k) k
= sum (ft1 k) 0 (mul_cell_bound a b i k)
= sum (smulf (mul_atom a b i k) (get c k j))
0 (mul_cell_bound a b i k)
= get c k j * sum (mul_atom a b i k) 0 (mul_cell_bound a b i k)
= get c k j * get ab i k
= mul_atom ab c i j k };
sum_ext (mul_atom ab c i j) (sumf ft1 0 m_a_bc) 0 m_ab_c;
assert { get (mul ab c) i j = sum (sumf ft1 0 m_a_bc) 0 m_ab_c };
assert { forall k. 0 <= k < m_a_bc -> mul_cell_bound b c k j <= m_ab_c
by mul_cell_bound b c k j <= col_zeros c j
so mul_cell_bound b c k j <= row_zeros b k
so row_zeros ab i = maxf (fun k -> row_zeros b k) 0 (row_zeros a i)
so 0 <= k < row_zeros a i
so row_zeros b k <= maxf (fun k -> row_zeros b k) 0 (row_zeros a i)
so row_zeros b k <= row_zeros ab i };
assert { forall k. 0 <= k < m_a_bc ->
sumf ft2 0 m_ab_c k = sumf ft2 0 (mul_cell_bound b c k j) k
by sumf ft2 0 m_ab_c k
= sum (ft2 k) 0 m_ab_c
= sum (ft2 k) 0 (mul_cell_bound b c k j)
+ sum (ft2 k) (mul_cell_bound b c k j) m_ab_c
= sumf ft2 0 (mul_cell_bound b c k j) k
+ sum (ft2 k) (mul_cell_bound b c k j) m_ab_c
so forall l. l >= mul_cell_bound b c k j -> ft2 k l = 0
so sum (ft2 k) (mul_cell_bound b c k j) m_ab_c = 0 };
assert { forall k. 0 <= k < m_a_bc ->
mul_atom a bc i j k = sumf ft2 0 m_ab_c k
by get bc k j = mul_cell b c k j
so sumf ft2 0 m_ab_c k
= sumf ft2 0 (mul_cell_bound b c k j) k
= sum (ft2 k) 0 (mul_cell_bound b c k j)
= sum (smulf (mul_atom b c k j) (get a i k))
0 (mul_cell_bound b c k j)
= get a i k * sum (mul_atom b c k j) 0 (mul_cell_bound b c k j)
= get a i k * get bc k j
= mul_atom a bc i j k };
sum_ext (mul_atom a bc i j) (sumf ft2 0 m_ab_c) 0 m_a_bc;
assert { get (mul a bc) i j = sum (sumf ft2 0 m_ab_c) 0 m_a_bc }
lemma mul_assoc:
forall a b c: mat.
let ab = mul a b in
let bc = mul b c in
let a_bc = mul a bc in
let ab_c = mul ab c in
a_bc = ab_c by a_bc == ab_c
let lemma mul_distr_right_get (a b c: mat) (i j: int)
requires { 0 <= i /\ 0 <= j }
ensures { get (mul (add a b) c) i j = get (add (mul a c) (mul b c)) i j }
= let mac = mul_atom a c i j in
let mbc = mul_atom b c i j in
let b_ac = mul_cell_bound a c i j in
let b_bc = mul_cell_bound b c i j in
let ma = max b_ac b_bc in
assert { get (add (mul a c) (mul b c)) i j = sum (addf mac mbc) 0 ma
by sum mac 0 ma = sum mac 0 b_ac + sum mac b_ac ma
so forall k. k >= b_ac -> mac k = 0
so sum mac b_ac ma = 0
so sum mac 0 b_ac = sum mac 0 ma
so sum mbc 0 ma = sum mbc 0 b_bc + sum mbc b_bc ma
so forall k. k >= b_bc -> mbc k = 0
so sum mbc b_bc ma = 0
so sum mbc 0 b_bc = sum mbc 0 ma
so get (mul a c) i j = sum mac 0 b_ac
so get (mul b c) i j = sum mbc 0 b_bc
so get (add (mul a c) (mul b c)) i j
= get (mul a c) i j + get (mul b c) i j
= sum mac 0 b_ac + sum mbc 0 b_bc
= sum mac 0 ma + sum mbc 0 ma
= sum (addf mac mbc) 0 ma };
sum_ext (addf mac mbc) (mul_atom (add a b) c i j) 0 ma;
assert { get (mul (add a b) c) i j = mul_cell (add a b) c i j }
(* External product *)
......@@ -1010,25 +1153,28 @@ use import int.Int
let predicate eq0_int (x:int) = x = 0
clone export AssocAlgebraDecision with type r = int, type a = mat, constant R.zero = Int.zero, constant R.one = Int.one, function R.(+) = (+), function R.(-_) = (-_), function R.(*) = (*),constant A.zero = mzero, constant one = id, function (+) = add, function A.(-_) = opp, function ( *) = mul, function asub = sub, function ($) = extp, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc, goal A.Mul_distr_l, goal A.Mul_distr_r, goal asub_def
clone export AssocAlgebraDecision with type r = int, type a = mat, constant R.zero = Int.zero, constant R.one = Int.one, function R.(+) = (+), function R.(-_) = (-_), function R.(*) = (*),constant A.zero = mzero, constant one = id, function (+) = add, function A.(-_) = opp, function ( *) = mul, function asub = sub, function ($) = extp, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc, goal A.Mul_distr_l, goal A.Mul_distr_r, goal asub_def, goal A.Inv_def_l, goal A.Inv_def_r
end
module MatrixTests
use import InfIntMatrixDecision
use import InfIntMatrix
use import int.Int
use import InfIntMatrixDecision
use import int.Sum
use import sum_extended.Sum_extended
use import Sum_extended
function cols (a: mat) : int (* if matrix is a finite rectangle, return number of cols *)
function rows (a: mat) : int
(* lemma t: forall a: mat, r1 r2 c: int. size a r1 c -> size a r2 c -> r1 = r2*)
axiom rows_def:
forall a: mat, r c: int. size a r c -> rows a = r
forall a: mat, r c: int. 0 <= r -> 0 <= c -> size a r c -> rows a = r
axiom cols_def:
forall a: mat, r c: int. size a r c -> cols a = c
forall a: mat, r c: int. 0 <= r -> 0 <= c -> size a r c -> cols a = c
predicate is_finite (m: mat) = size m m.rows m.cols
......@@ -1038,6 +1184,7 @@ module MatrixTests
function block (a: mat) (r dr c dc: int) : mat =
fcreate dr dc (ofs2 a r c)
predicate c_blocks (a a1 a2: mat) =
0 <= a1.cols <= a.cols /\ a1 = block a 0 a.rows 0 a1.cols /\
a2 = block a 0 a.rows a1.cols (a.cols - a1.cols)
......@@ -1055,24 +1202,47 @@ module MatrixTests
sum (mul_atom a b i j) 0 k = sum (mul_atom a1 b1 i j) 0 k }
ensures { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
a1.cols <= k <= a.cols ->
sum (mul_atom a b i j) 0 k =
sum (mul_atom a b i j) 0 k =
sum (mul_atom a1 b1 i j) 0 a1.cols +
sum (mul_atom a2 b2 i j) 0 (k - a1.cols) }
variant { k }
= if 0 < k then begin
let k = k - 1 in
assert { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
mul_atom a b i j k = if k < a1.cols then mul_atom a1 b1 i j k
else mul_atom a2 b2 i j (k - a1.cols) };
if k < a1.cols
then mul_atom a b i j k = mul_atom a1 b1 i j k
else (mul_atom a b i j k = mul_atom a2 b2 i j (k - a1.cols)
by get a i k = get a2 i (k - a1.cols)
so get b k j = get b2 (k-a1.cols) j)};
block_mul_ij a a1 a2 b b1 b2 k
end
let lemma mul_split (a a1 a2 b b1 b2: mat) : unit
requires { is_finite a /\ is_finite b }
requires { a.cols = b.rows /\ a1.cols = b1.rows}
requires { 0 < a.rows /\ 0 < a.cols /\ 0 < b.cols
/\ 0 < a1.cols /\ 0 < a2.cols }
requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
ensures {add (mul a1 b1) (mul a2 b2) = mul a b }
ensures { add (mul a1 b1) (mul a2 b2) = mul a b }
= block_mul_ij a a1 a2 b b1 b2 a.cols;
assert { add (mul a1 b1) (mul a2 b2) == mul a b }
mul_sizes a b a.rows a.cols b.cols;
mul_sizes a1 b1 a.rows a1.cols b.cols;
mul_sizes a2 b2 a.rows (a.cols - a1.cols) b.cols;
assert { add (mul a1 b1) (mul a2 b2) === mul a b
by size (add (mul a1 b1) (mul a2 b2)) a.rows b.cols
so size (mul a b) a.rows b.cols };
assert { forall i j. in_bounds (mul a b) i j ->
get (add (mul a1 b1) (mul a2 b2)) i j = get (mul a b) i j
by mul_cell_bound a1 b1 i j = a1.cols
so mul_cell_bound a2 b2 i j = a2.cols = a.cols - a1.cols
so get (mul a b) i j
= mul_cell a b i j
= sum (mul_atom a b i j) 0 a.cols
= sum (mul_atom a1 b1 i j) 0 a1.cols
+ sum (mul_atom a2 b2 i j) 0 (a.cols - a1.cols)
= get (mul a1 b1) i j + get (mul a2 b2) i j
= get (add (mul a1 b1) (mul a2 b2)) i j };
ext_by_bounds (add (mul a1 b1) (mul a2 b2)) (mul a b)
let lemma mul_block_cell (a b: mat) (r dr c dc i j: int) : unit
requires { is_finite a /\ is_finite b }
......@@ -1090,15 +1260,15 @@ module MatrixTests
= sum (mul_atom a' b' i j) 0 a.cols
= get (mul a' b') i j }
lemma mul_block:
forall a b: mat, r dr c dc: int.
a.cols = b.rows -> 0 <= r <= r + dr <= a.rows ->
0 <= c <= c + dc <= b.cols ->
let a' = block a r dr 0 a.cols in
let b' = block b 0 b.rows c dc in
let m' = block (mul a b) r dr c dc in
m' = mul a' b'
by m' == mul a' b'
let lemma mul_block (a b a' b' m': mat) (r dr c dc: int)
requires { a.cols = b.rows }
requires { 0 <= r <= r + dr <= a.rows }
requires { 0 <= c <= c + dc <= b.cols }
requires { a' = block a r dr 0 a.cols }
requires { b' = block b 0 b.rows c dc }
requires { m' = block (mul a b) r dr c dc }
ensures { m' = mul a' b' }
= assert { m' == mul a' b' }
predicate quarters (a a11 a12 a21 a22: mat) =
(is_finite a /\ is_finite a11 /\ is_finite a12 /\ is_finite a21 /\ is_finite a22) /\
......@@ -1106,20 +1276,7 @@ module MatrixTests
rows a = cols a = 2 * rows a11 /\
a11 = block a 0 a11.rows 0 a11.cols /\ a12 = block a 0 a11.rows a11.cols a11.cols /\
a21 = block a a11.rows a11.rows 0 a11.cols /\ a22 = block a a11.rows a11.rows a11.cols a11.cols
(*
let lemma naive_blocks (a b c a11 a12 a21 a22 b11 b12 b21 b22 c11 c12 c21 c22: mat)
requires { is_finite a /\ is_finite b /\ is_finite c }
requires { quarters a a11 a12 a21 a22 }
requires { quarters b b11 b12 b21 b22 }
requires { quarters c c11 c12 c21 c22 }
requires { c11 = add (mul a11 b11) (mul b11 b22) }
requires { c12 = add (mul a11 b12) (mul a12 b22) }
requires { c21 = add (mul a21 b11) (mul a22 b21) }
requires { c22 = add (mul a21 b12) (mul a22 b22) }
ensures { c = mul a b }
=
assert { c == mul a b } Z3 proves this ?
*)
let lemma naive_blocks (a b c a11 a12 a21 a22 b11 b12 b21 b22 c11 c12 c21 c22: mat)
requires { is_finite a /\ is_finite b /\ is_finite c }
requires { quarters a a11 a12 a21 a22 }
......
open Term
open Ty
open Decl
open Theory
open Ident
let meta_reify_target = Theory.register_meta_excl "reify_target" [Theory.MTlsymbol]
~desc:"Declares@ the@ given@ interpretation@ function@ as@ the@ function@ to@ be@ inverted@ at@ reification."
let meta_normalize_function = Theory.register_meta_excl "reify_normalize" [Theory.MTlsymbol]
~desc:"Declares@ the@ given@ function@ as@ the@ normalization@ function@ for@ reified@ terms@."
open Args_wrapper
(* target: t = V int | ...
interp: t -> (int -> 'a) -> 'a *)
......@@ -26,45 +20,13 @@ exception Exit
let debug = true
let expl_reified_goal = Ident.create_label "expl:reified goal"
let expl_reification_check = Ident.create_label "expl:reification check"
let expl_normalized_goal = Ident.create_label "expl:normalized goal"
let expl_normalization_check = Ident.create_label "expl:normalization check"
let collect_reify_targets_t =
Trans.on_meta_excl meta_reify_target
(function
| None ->
if debug then Format.printf "no reify target declared@.";
raise Exit
| Some [Theory.MAls i]
-> Trans.return i
| _ -> assert false)
let collect_normalize_t interp =
Trans.on_meta_excl meta_normalize_function
(function
| None ->
if debug then Format.printf "no normalize declared@.";
raise Exit
| Some [Theory.MAls n]
-> Trans.return (interp, n)
| _ -> assert false)
let collect_interp_normalize =
Trans.bind collect_reify_targets_t collect_normalize_t
let reify_goal interp task =
let kn = Task.task_known task in
let ty_vars, ty_val = match interp.ls_args, interp.ls_value with
| [ _ty_target; ty_vars ], Some ty_val
when ty_equal ty_vars (ty_func ty_int ty_val)
-> ty_vars, ty_val
| _ -> raise Exit in
let ly = create_fsymbol (Ident.id_fresh "y") [] ty_vars in
let y = t_app ly [] (Some ty_vars) in
let rec invert_pat vl (env, fr) (p,f) t =
let reflection_by_lemma pr : Task.task Trans.tlist = Trans.store (fun task ->
let open Task in
let kn = task_known task in
let rec invert_pat vl (env, fr) interp (p,f) t =
if debug
then Format.printf "invert_pat p %a f %a t %a@."
Pretty.print_pat p Pretty.print_term f Pretty.print_term t;
......@@ -73,9 +35,9 @@ let reify_goal interp task =
| Papp (cs, [{pat_node = Pvar v1}]),
Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}]),
Tvar _
| Papp (cs, [{pat_node = Pvar v1}]),
Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}]),
Tapp(_, [])
| Papp (cs, [{pat_node = Pvar v1}]),
Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}]),
Tapp(_, [])
when ty_equal v1.vs_ty ty_int
&& Svs.mem v1 p.pat_vars
&& vs_equal v1 v2
......@@ -103,7 +65,7 @@ let reify_goal interp task =
let env, fr, rl =
fold_left3
(fun (env, fr, acc) p f t ->
let env, fr, nt = invert_pat vl (env, fr) (p, f) t in
let env, fr, nt = invert_pat vl (env, fr) interp (p, f) t in
if debug
then Format.printf "param %a matched@." Pretty.print_term t;
(env, fr, nt::acc))
......@@ -114,17 +76,17 @@ let reify_goal interp task =
(Pp.print_list Pp.comma Pretty.print_term)
(List.rev rl)
;
let t = t_app cs (List.rev rl) cs.ls_value in
if debug then Format.printf "app ok@.";
env, fr, t
let t = t_app cs (List.rev rl) cs.ls_value in
if debug then Format.printf "app ok@.";
env, fr, t
| Papp _, Tapp (ls1, _), Tapp(ls2, _) ->
if debug then Format.printf "head symbol mismatch %a %a@."
Pretty.print_ls ls1 Pretty.print_ls ls2;
raise Exit
| Por (p1, p2), _, _ ->
if debug then Format.printf "case or@.";
begin try invert_pat vl (env, fr) (p1, f) t
with Exit -> invert_pat vl (env, fr) (p2, f) t
begin try invert_pat vl (env, fr) interp (p1, f) t
with Exit -> invert_pat vl (env, fr) interp (p2, f) t
end
| Pvar _, Tvar _, Tvar _ | Pvar _, Tvar _, Tapp (_, [])
| Pvar _, Tvar _, Tconst _
......@@ -153,146 +115,101 @@ let reify_goal interp task =
let rec aux = function
| [] -> raise Exit
| tb::l ->
try invert_pat vl (env, fr) (t_open_branch tb) t
try invert_pat vl (env, fr) ls (t_open_branch tb) t
with Exit -> if debug then Format.printf "match failed@."; aux l in
aux bl
| _ -> raise Exit
in
let reify_term (env, fr) (t:term) =
let reify_term (env, fr, subst) (lv, vy) t rt =
if debug then Format.printf "reify_term %a@." Pretty.print_term t;
match t.t_node with
| Tquant (Tforall, _) ->
raise Exit (* we introduce premises before the transformation *)
| _ when oty_equal t.t_ty interp.ls_value ->
match t.t_node, rt.t_node with
| _, Tapp(interp, [{t_node = Tvar vx}; {t_node = Tvar vy'} ])
when oty_equal t.t_ty interp.ls_value && Svs.mem vx lv && vs_equal vy vy' ->
if debug then Format.printf "case interp@.";
let env, fr, x = invert_interp (env, fr) interp t in
env, fr, t_app interp [x; y] (Some ty_val)
env, fr, Mvs.add vx x subst (*t_app interp [x; y] (Some ty_val)*)
| _ ->
if debug then
Format.printf "wrong type: t.ty %a interp.ls_value %a@."
Pretty.print_ty (Opt.get t.t_ty)
Pretty.print_ty (Opt.get interp.ls_value);
(*if debug then
Format.printf "wrong type: t.ty %a interp.ls_value %a@."
Pretty.print_ty (Opt.get t.t_ty)
Pretty.print_ty (Opt.get interp.ls_value);*)
raise Exit
in
let open Task in
match task with
| Some
{ task_decl =
{ td_node = Decl { d_node = Dprop (Pgoal, _, f) } };
task_prev = prev;
} ->
begin try
if debug then Format.printf "start@.";
begin match f.t_node with
| Tapp(ls, [f1; f2]) when ls_equal ls ps_equ ->
if debug then Format.printf "case =@.";
let (env, fr, t1) = reify_term (Mterm.empty, 0) f1 in
let (env, _fr, t2) = reify_term (env, fr) f2 in
let t = t_equ t1 t2 in
if debug then Format.printf "building y map@.";
let d = create_param_decl ly in