Commit 9f7caea9 authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft
Browse files

Allow reflection on a program function's contract

parent 901d4eae
......@@ -318,7 +318,7 @@ val predicate eq0 (r: r) ensures { result <-> r = R.zero }
end
theory AssocAlgebraDecision
module AssocAlgebraDecision
use import int.Int
......@@ -330,7 +330,7 @@ type a
val constant rzero : r
val constant rone : r
val constant aone : a
val constant azero : a
val ghost constant azero : a
val function rplus r r : r
val function rtimes r r : r
......@@ -340,9 +340,9 @@ val function aplus a a : a
val function atimes a a : a
val function aopp a : a
clone export AssocAlgebra with type r = r, type a = a, constant one = aone, constant R.zero = rzero, constant R.one = rone, function R.(+) = rplus, function R.( *) = rtimes, function R.(-_) = ropp, function (+) = aplus, function ( *) = atimes, function A.(-_) = aopp
clone export AssocAlgebra with type r = r, type a = a, constant one = aone, constant A.zero = azero, constant R.zero = rzero, constant R.one = rone, function R.(+) = rplus, function R.( *) = rtimes, function R.(-_) = ropp, function (+) = aplus, function ( *) = atimes, function A.(-_) = aopp
axiom azero_def: azero = A.zero (* FIXME *)
(*axiom azero_def: azero = A.zero*) (* FIXME *)
type t = Var int | Add t t | Mul t t | Ext r t | Sub t t
type vars = int -> a
......@@ -376,7 +376,7 @@ let rec function mon (x: list int) (y: vars) : a =
| Cons x l -> atimes (y x) (mon l y)
end
let rec function interp' (x: t') (y: vars) : a =
let rec ghost function interp' (x: t') (y: vars) : a =
match x with
| Nil -> azero
| Cons (M r m) l -> aplus (($) r (mon m y)) (interp' l y) end
......@@ -513,7 +513,7 @@ let lemma norm' (x1 x2: t')
ensures { eq' x1 x2 }
= ()
let function norm_f (x1 x2: t) : bool
let norm_f (x1 x2: t) : bool
ensures { forall y: vars. result = true -> interp x1 y = interp x2 y }
= match normalize' (conv (Sub x1 x2)) with
| Nil -> true
......@@ -847,9 +847,11 @@ module InfIntMatrix
let constant zerof : int -> int -> int = fun _ _ -> 0
val constant mzero : mat
(*val constant mzero : mat
axiom mzero_def: mzero = fcreate 0 0 zerof (*FIXME*)
axiom mzero_def: mzero = fcreate 0 0 zerof*) (*FIXME*)
let ghost constant mzero : mat = fcreate 0 0 zerof
let ghost function zerorc (r c: int) : mat = fcreate r c zerof
......
......@@ -19,12 +19,12 @@ let debug = true
let expl_reification_check = Ident.create_label "expl:reification check"
type reify_env = { kn: known_map;
store: int Mterm.t;
store: (vsymbol * int) Mterm.t;
fr: int;
subst: term Mvs.t;
lv: Svs.t;
vy: vsymbol option;
ty_val: ty option;
var_maps: ty Mvs.t; (* type of values pointed by each map*)
ty_to_map: vsymbol Mty.t;
}
let init_renv kn lv = { kn=kn;
......@@ -32,56 +32,30 @@ let init_renv kn lv = { kn=kn;
fr = 0;
subst = Mvs.empty;
lv = lv;
vy = None;
ty_val = None;
var_maps = Mvs.empty;
ty_to_map = Mty.empty;
}
let rec reify_term renv t rt =
let rec invert_pat vl (env, fr) interp (p,f) t =
let rec invert_pat vl (renv:reify_env) interp (p,f) t =
if debug
then Format.printf
"invert_pat p %a f %a t %a@."
Pretty.print_pat p Pretty.print_term f Pretty.print_term t;
match p.pat_node, f.t_node, t.t_node with
| Pwild, _, _ -> raise NoReification
| Papp (cs, [{pat_node = Pvar v1}]),
Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}]),
Tvar _
| Papp (cs, [{pat_node = Pvar v1}]),
Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}]),
Tapp(_, [])
when ty_equal v1.vs_ty ty_int
&& Svs.mem v1 p.pat_vars
&& vs_equal v1 v2
&& ls_equal ffa fs_func_app
&& List.exists (fun vs -> vs_equal vs vy) vl (*FIXME*)
->
if debug then Format.printf "case var@.";
let rty = cs.ls_value in
if Mterm.mem t env
then
begin
if debug then Format.printf "%a exists@." Pretty.print_term t;
(env, fr, t_app cs [t_nat_const (Mterm.find t env)] rty)
end
else
begin
if debug then Format.printf "%a is new@." Pretty.print_term t;
let env = Mterm.add t fr env in
(env, fr+1, t_app cs [t_nat_const fr] rty)
end
| Papp (cs, pl), Tapp(ls1, la1), Tapp(ls2, la2) when ls_equal ls1 ls2
->
if debug then Format.printf "case app@.";
(* same head symbol, match parameters *)
let env, fr, rl =
let renv, rl =
fold_left3
(fun (env, fr, acc) p f t ->
let env, fr, nt = invert_pat vl (env, fr) interp (p, f) t in
(fun (renv, acc) p f t ->
let renv, nt = invert_pat vl renv interp (p, f) t in
if debug
then Format.printf "param %a matched@." Pretty.print_term t;
(env, fr, nt::acc))
(env, fr, []) pl la1 la2 in
(renv, nt::acc))
(renv, []) pl la1 la2 in
if debug then Format.printf "building app %a of type %a with args %a@."
Pretty.print_ls cs
Pretty.print_ty (Opt.get cs.ls_value)
......@@ -89,47 +63,77 @@ let rec reify_term renv t rt =
(List.rev rl);
let t = t_app cs (List.rev rl) cs.ls_value in
if debug then Format.printf "app ok@.";
env, fr, t
renv, t
| Papp _, Tapp (ls1, _), Tapp(ls2, _) ->
if debug then Format.printf "head symbol mismatch %a %a@."
Pretty.print_ls ls1 Pretty.print_ls ls2;
raise NoReification
| Por (p1, p2), _, _ ->
if debug then Format.printf "case or@.";
begin try invert_pat vl (env, fr) interp (p1, f) t
with NoReification -> invert_pat vl (env, fr) interp (p2, f) t
begin try invert_pat vl renv interp (p1, f) t
with NoReification -> invert_pat vl renv interp (p2, f) t
end
| Pvar _, Tvar _, Tvar _ | Pvar _, Tvar _, Tapp (_, [])
| Pvar _, Tvar _, Tconst _
-> if debug then Format.printf "case vars@.";
(env, fr, t)
(renv, t)
| Pvar _, Tapp (ls, _la), _ when ls_equal ls interp
-> if debug then Format.printf "case interp@.";
invert_interp (env, fr) ls t
invert_interp renv ls t
(*| Papp (cs, pl), Tapp (ls1, la1), _ when Sls.mem ls1 !reify_invert
-> (* Cst c -> morph c <- 42 ? *) *)
| _ -> raise NoReification
and invert_interp (env, fr) ls (t:term) = (*la ?*)
and invert_var_pat vl (renv:reify_env) _interp (p,f) t =
if debug
then Format.printf
"invert_var_pat p %a f %a t %a@."
Pretty.print_pat p Pretty.print_term f Pretty.print_term t;
match p.pat_node, f.t_node, t.t_node with
| Papp (cs, [{pat_node = Pvar v1}]),
Tapp (ffa,[{t_node = Tvar vy}; {t_node = Tvar v2}]), _
when ty_equal v1.vs_ty ty_int
&& Svs.mem v1 p.pat_vars
&& vs_equal v1 v2
&& ls_equal ffa fs_func_app
&& List.exists (fun vs -> vs_equal vs vy) vl (*FIXME*)
->
if debug then Format.printf "case var@.";
let rty = cs.ls_value in
if Mterm.mem t renv.store
then
begin
if debug then Format.printf "%a exists@." Pretty.print_term t;
(renv, t_app cs [t_nat_const (snd (Mterm.find t renv.store))] rty)
end
else
begin
if debug then Format.printf "%a is new@." Pretty.print_term t;
let fr = renv.fr in
let vy = Mty.find vy.vs_ty renv.ty_to_map in
let store = Mterm.add t (vy, fr) renv.store in
let renv = { renv with store = store; fr = fr + 1 } in
(renv, t_app cs [t_nat_const fr] rty)
end
| _ -> raise NoReification
and invert_interp renv ls (t:term) = (*la ?*)
let ld = Opt.get (find_logic_definition renv.kn ls) in
let vl, f = open_ls_defn ld in
(*assert (oty_equal f.t_ty t.t_ty);*)
if debug then Format.printf "invert_interp ls %a t %a@."
Pretty.print_ls ls Pretty.print_term t;
match f.t_node, t.t_node with
| Tcase (x, bl), _ ->
(*FIXME*)
assert (List.length vl = 2);
(match x.t_node with
| Tvar v when vs_equal v (List.hd vl) -> ()
| _ -> assert false);
if debug then Format.printf "case match@.";
let rec aux = function
let rec aux invert = function
| [] -> raise NoReification
| tb::l ->
try invert_pat vl (env, fr) ls (t_open_branch tb) t
try invert vl renv ls (t_open_branch tb) t
with NoReification ->
if debug then Format.printf "match failed@."; aux l in
aux bl
if debug then Format.printf "match failed@."; aux invert l in
(try aux invert_pat bl with NoReification -> aux invert_var_pat bl)
| _ -> raise NoReification in
if debug then Format.printf "reify_term t %a rt %a@."
Pretty.print_term t Pretty.print_term rt;
......@@ -139,24 +143,17 @@ let rec reify_term renv t rt =
Pretty.print_ty (Opt.get t.t_ty) Pretty.print_ty (Opt.get rt.t_ty);
raise NoReification);
match t.t_node, rt.t_node with
| _, Tapp(interp, [{t_node = Tvar vx}; {t_node = Tvar vy'} ])
when Svs.mem vx renv.lv && Svs.mem vy' renv.lv ->
| _, Tapp(interp, [{t_node = Tvar vx}; {t_node = Tvar vy} ])
when Svs.mem vx renv.lv && Svs.mem vy renv.lv ->
if debug then Format.printf "case interp@.";
if renv.vy <> None && not (vs_equal (Opt.get renv.vy) vy')
then (if debug then Format.printf "y map conflict@.";
raise NoReification);
let store, fr, x = invert_interp (renv.store, renv.fr) interp t in
let renv = { renv with store = store; fr = fr; subst = Mvs.add vx x renv.subst } in
if renv.vy = None
then begin
assert (renv.ty_val = None);
let ty_val = match interp.ls_args, interp.ls_value with
| [ _ty_target; ty_vars ], Some ty_val
when ty_equal ty_vars (ty_func ty_int ty_val)
-> ty_val
| _ -> raise NoReification in
{renv with vy = Some vy'; ty_val = Some ty_val } end
else renv
let var_maps, ty_to_map =
if Mty.mem vy.vs_ty renv.ty_to_map
then renv.var_maps, renv.ty_to_map
else (Mvs.add vy (Opt.get interp.ls_value) renv.var_maps,
Mty.add vy.vs_ty vy renv.ty_to_map) in
let renv = { renv with var_maps = var_maps; ty_to_map = ty_to_map } in
let renv, x = invert_interp renv interp t in
{ renv with subst = Mvs.add vx x renv.subst }
| Tapp(eq, [t1; t2]), Tapp (eq', [rt1; rt2])
when ls_equal eq ps_equ && ls_equal eq' ps_equ ->
if debug then Format.printf "case eq@.";
......@@ -164,31 +161,35 @@ let rec reify_term renv t rt =
| _ -> if debug then Format.printf "no reify_term match@."; raise NoReification
let build_vars_map renv prev =
if debug then Format.printf "building vars map@.";
let ty_val = Opt.get renv.ty_val in
let ty_vars = ty_func ty_int ty_val in
let ly = create_fsymbol (Ident.id_fresh "y") [] ty_vars in
let y = t_app ly [] (Some ty_vars) in
let vy = Opt.get renv.vy in
let subst = Mvs.add vy y renv.subst in
if not (Svs.for_all (fun v -> Mvs.mem v subst) renv.lv)
then (if debug
then Format.printf "some vars not matched, todo use context";
raise Exit);
let d = create_param_decl ly in
let prev = Task.add_decl prev d in
let prev = Mterm.fold
(fun t i prev ->
let et = t_equ
(t_app fs_func_app [y; t_nat_const i]
(Some ty_val))
t in
if debug then Format.printf "eq_term ok@.";
let pr = create_prsymbol (Ident.id_fresh "y_val") in
let d = create_prop_decl Paxiom pr et in
Task.add_decl prev d)
renv.store prev in
subst, prev
if debug then Format.printf "building vars map@.";
let subst, prev = Mvs.fold
(fun vy ty_val (subst, prev) ->
let ty_vars = ty_func ty_int ty_val in
let ly = create_fsymbol (Ident.id_fresh vy.vs_name.id_string)
[] ty_vars in
let y = t_app ly [] (Some ty_vars) in
let d = create_param_decl ly in
let prev = Task.add_decl prev d in
Mvs.add vy y subst, prev)
renv.var_maps (renv.subst, prev) in
if not (Svs.for_all (fun v -> Mvs.mem v subst) renv.lv)
then (if debug
then Format.printf "some vars not matched, todo use context@.";
raise Exit);
let prev = Mterm.fold
(fun t (vy,i) prev ->
let y = Mvs.find vy subst in
let ty_val = Mvs.find vy renv.var_maps in
let et = t_equ
(t_app fs_func_app [y; t_nat_const i]
(Some ty_val))
t in
if debug then Format.printf "eq_term ok@.";
let pr = create_prsymbol (Ident.id_fresh "y_val") in
let d = create_prop_decl Paxiom pr et in
Task.add_decl prev d)
renv.store prev in
subst, prev
let build_goals prev subst lp g rt =
if debug then Format.printf "building goals@.";
......@@ -239,7 +240,6 @@ let reflection_by_lemma pr : Task.task Trans.tlist = Trans.store (fun task ->
build_goals prev subst lp g rt
with NoReification | Exit -> [task])
open Mltree
open Expr
open Ity
......@@ -594,20 +594,27 @@ let rec term_of_value = function
(*exception FunctionNotFound*)
let reflection_by_function ls env = Trans.store (fun task ->
let reflection_by_function s env = Trans.store (fun task ->
if debug then Format.printf "reflection_f start@.";
let kn = task_known task in
let g, prev = Task.task_separate_goal task in
let g = Apply.term_decl g in
(*if debug then Format.printf "reading theory@.";
let th = Env.read_theory env ["compute"] ths in
let pmod = Pmodule.restore_module th in
let rs = try Pmodule.ns_find_rs pmod.Pmodule.mod_export [f]
with Not_found -> raise FunctionNotFound in*)
let rs = Expr.restore_rs ls in
let mith = Task.(used_symbols (used_theories task)) in
let th = Mid.find rs.rs_name mith in
let pmod = Pmodule.restore_module th in
let ths = Task.used_theories task in
let o =
Mid.fold
(fun _ th o ->
try
let pmod = Pmodule.restore_module th in
let rs = Pmodule.ns_find_rs pmod.Pmodule.mod_export [s] in
if o = None then Some (pmod, rs)
else (if debug then Format.printf "Name conflict %s@." s;
raise Exit)
with Not_found -> o)
ths None in
let (pmod, rs) = if o = None
then (if debug then Format.printf "Symbol %s not found@." s;
raise Exit)
else Opt.get o in
let (_, ms, _) = Pmodule.restore_path rs.rs_name in
let lpost = List.map open_post rs.rs_cty.cty_post in
if List.exists (fun pv -> pv.pv_ghost) rs.rs_cty.cty_args
......@@ -663,7 +670,7 @@ let () = wrap_and_register
let () = wrap_and_register
~desc:"reflection_f <f> attempts to prove the goal by reflection using the contract of the function f"
"reflection_f"
(Tlsymbol Tenvtrans_l) reflection_by_function
(Tstring Tenvtrans_l) reflection_by_function
(*
Local Variables:
......
val reflection_by_lemma: Decl.prsymbol -> Task.task Trans.tlist
val reflection_by_function: string -> Env.env -> Task.task Trans.tlist
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment