Commit eec946f3 authored by Andrei Paskevich's avatar Andrei Paskevich

whyml: remove redundant invariants

parent 86d1bc0f
......@@ -30,6 +30,8 @@ open Mlw_ty.T
open Mlw_expr
let debug = Debug.register_flag "whyml_wp"
let no_track = Debug.register_flag "wp_no_track"
let no_eval = Debug.register_flag "wp_no_eval"
(** Marks *)
......@@ -257,11 +259,11 @@ let decrease ?loc env olds varl =
(** Reconstruct pure values after writes *)
let find_constructors lkn kn sts ity = match ity.ity_node with
let find_constructors lkm km sts ity = match ity.ity_node with
| Itypur (ts,_) ->
let base = ity_pur ts (List.map ity_var ts.ts_args) in
let sbs = ity_match ity_subst_empty base ity in
let csl = Decl.find_constructors lkn ts in
let csl = Decl.find_constructors lkm ts in
if csl = [] || Sts.mem ts sts then Loc.errorm
"Cannot update values of type %a" Mlw_pretty.print_ity base;
let subst ty = ity_full_inst sbs (ity_of_ty ty), None in
......@@ -270,7 +272,7 @@ let find_constructors lkn kn sts ity = match ity.ity_node with
| Ityapp (its,_,_) ->
let base = ity_app its (List.map ity_var its.its_args) its.its_regs in
let sbs = ity_match ity_subst_empty base ity in
let csl = Mlw_decl.find_constructors kn its in
let csl = Mlw_decl.find_constructors km its in
if csl = [] || Sts.mem its.its_pure sts then Loc.errorm
"Cannot update values of type %a" Mlw_pretty.print_ity base;
let subst vtv =
......@@ -280,8 +282,8 @@ let find_constructors lkn kn sts ity = match ity.ity_node with
Sts.add its.its_pure sts, List.map cnstr csl
| Ityvar _ -> assert false
let analyze_var fn_down fn_join lkn kn sts vs ity =
let sts, csl = find_constructors lkn kn sts ity in
let analyze_var fn_down fn_join lkm km sts vs ity =
let sts, csl = find_constructors lkm km sts ity in
let branch (cs,ityl) =
let mk_var (ity,_) = create_vsymbol (id_fresh "y") (ty_of_ity ity) in
let vars = List.map mk_var ityl in
......@@ -368,39 +370,86 @@ let quantify env regs f =
let f = Mvs.fold update vars (subst_at_now true vv' f) in
wp_forall (List.rev (Mreg.values mreg)) f
(* value tracking *)
(** Invariants *)
let get_invariant km t =
let ty = t_type t in
let ts = match ty.ty_node with
| Tyapp (ts,_) -> ts
| _ -> assert false in
let rec find_td = function
| (its,_,inv) :: _ when ts_equal ts its.its_pure -> inv
| _ :: tdl -> find_td tdl
| [] -> assert false in
let pd = Mid.find ts.ts_name km in
let inv = match pd.Mlw_decl.pd_node with
| Mlw_decl.PDdata tdl -> find_td tdl
| _ -> assert false in
let sbs = Ty.ty_match Mtv.empty (t_type inv) ty in
let u, p = open_post (t_ty_subst sbs Mvs.empty inv) in
wp_expl expl_inv (t_subst_single u t p)
let ps_inv = Term.create_psymbol (id_fresh "inv")
[ty_var (create_tvsymbol (id_fresh "a"))]
let full_invariant lkm km vs ity =
let rec update sts vs ity _ =
if not (ity_inv ity) then t_true else
(* what is our current invariant? *)
let f = match ity.ity_node with
| Ityapp (its,_,_) when its.its_inv ->
if Debug.test_flag no_track
then get_invariant km (t_var vs)
else ps_app ps_inv [t_var vs]
| _ -> t_true in
(* what are our sub-invariants? *)
let join _ fl _ = wp_ands ~sym:true fl in
let g = analyze_var update join lkm km sts vs ity in
(* put everything together *)
wp_and ~sym:true f g
in
update Sts.empty vs ity None
(** Value tracking *)
type point = int
type value = point list Mls.t (* constructor -> field list *)
type state = {
kn : Decl.known_map;
memory : (point, value) Hashtbl.t;
next : point ref;
st_km : Mlw_decl.known_map;
st_lkm : Decl.known_map;
st_mem : (point, value) Hashtbl.t;
st_next : point ref;
}
type names = point Mvs.t
let empty_state kn = {
kn = kn;
memory = Hashtbl.create 5;
next = ref 0;
type names = point Mvs.t (* variable -> point *)
type condition = lsymbol Mint.t (* point -> constructor *)
type lesson = condition list Mint.t (* point -> conditions for invariant *)
let empty_state lkm km = {
st_km = km;
st_lkm = lkm;
st_mem = Hashtbl.create 5;
st_next = ref 0;
}
let next_point state =
let res = !(state.next) in incr state.next; res
let res = !(state.st_next) in incr state.st_next; res
let make_value state ty =
let get_p _ = next_point state in
let new_cs cs = List.map get_p cs.ls_args in
let add_cs m (cs,_) = Mls.add cs (new_cs cs) m in
let csl = match ty.ty_node with
| Tyapp (ts,_) -> Decl.find_constructors state.kn ts
| _ -> assert false in
| Tyapp (ts,_) -> Decl.find_constructors state.st_lkm ts
| _ -> [] in
List.fold_left add_cs Mls.empty csl
let match_point state ty p =
try Hashtbl.find state.memory p with Not_found ->
try Hashtbl.find state.st_mem p with Not_found ->
let value = make_value state ty in
Hashtbl.replace state.memory p value;
if not (Mls.is_empty value) then
Hashtbl.replace state.st_mem p value;
value
let rec open_pattern state names value p pat = match pat.pat_node with
......@@ -421,7 +470,7 @@ let rec point_of_term state names t = match t.t_node with
| Tvar vs ->
Mvs.find vs names
| Tapp (ls, tl) ->
begin match Mid.find ls.ls_name state.kn with
begin match Mid.find ls.ls_name state.st_lkm with
| { Decl.d_node = Decl.Ddata tdl } ->
let is_cs (cs,_) = ls_equal ls cs in
let is_cs (_,csl) = List.exists is_cs csl in
......@@ -460,7 +509,7 @@ let rec point_of_term state names t = match t.t_node with
begin try
let value = List.fold_left branch Mls.empty bl in
let value = Mls.set_union value (make_value state ty) in
Hashtbl.replace state.memory p value
Hashtbl.replace state.st_mem p value
with Exit -> () end;
p
| Tconst _ | Tif _ | Teps _ -> next_point state
......@@ -471,13 +520,13 @@ and point_of_constructor state names ls tl =
let pl = List.map (point_of_term state names) tl in
let value = make_value state (of_option ls.ls_value) in
let value = Mls.add ls pl value in
Hashtbl.replace state.memory p value;
Hashtbl.replace state.st_mem p value;
p
and point_of_projection state names ls t1 =
let ty = of_option t1.t_ty in
let csl = match ty.ty_node with
| Tyapp (ts,_) -> Decl.find_constructors state.kn ts
| Tyapp (ts,_) -> Decl.find_constructors state.st_lkm ts
| _ -> assert false in
match csl with
| [cs,pjl] ->
......@@ -490,43 +539,71 @@ and point_of_projection state names ls t1 =
find_p pjl (Mls.find cs value)
| _ -> next_point state (* more than one, can't choose *)
(* invariants *)
let get_invariant kn v =
let ts = match v.vs_ty.ty_node with
| Tyapp (ts,_) -> ts
| _ -> assert false in
let rec find_td = function
| (its,_,inv) :: _ when ts_equal ts its.its_pure -> inv
| _ :: tdl -> find_td tdl
| [] -> assert false in
let pd = Mid.find ts.ts_name kn in
let inv = match pd.Mlw_decl.pd_node with
| Mlw_decl.PDdata tdl -> find_td tdl
| _ -> assert false in
let sbs = Ty.ty_match Mtv.empty (t_type inv) v.vs_ty in
let u, p = open_post (t_ty_subst sbs Mvs.empty inv) in
wp_expl expl_inv (t_subst_single u (t_var v) p)
let ps_inv = Term.create_psymbol (id_fresh "inv")
[ty_var (create_tvsymbol (id_fresh "a"))]
let full_invariant lkn kn vs ity =
let rec update sts vs ity _ =
if not (ity_inv ity) then t_true else
(* what is our current invariant? *)
let f = match ity.ity_node with
| Ityapp (its,_,_) when its.its_inv ->
get_invariant kn vs
(* ps_app ps_inv [t_var vs] *)
| _ -> t_true in
(* what are our sub-invariants? *)
let join _ fl _ = wp_ands ~sym:true fl in
let g = analyze_var update join lkn kn sts vs ity in
(* put everything together *)
wp_and ~sym:true f g
in
update Sts.empty vs ity None
let rec track_values state names lesson cond f = match f.t_node with
| Tapp (ls, [t1]) when ls_equal ls ps_inv ->
let p1 = point_of_term state names t1 in
let condl = Mint.find_def [] p1 lesson in
let contains c1 c2 = Mint.submap (fun _ -> ls_equal) c2 c1 in
if List.exists (contains cond) condl then
lesson, t_true
else
let good c = not (contains c cond) in
let condl = List.filter good condl in
let l = Mint.add p1 (cond::condl) lesson in
l, get_invariant state.st_km t1
| Tbinop (Timplies, f1, f2) ->
let l, f1 = track_values state names lesson cond f1 in
let _, f2 = track_values state names l cond f2 in
lesson, t_label_copy f (t_implies_simp f1 f2)
| Tbinop (Tand, f1, f2) ->
let l, f1 = track_values state names lesson cond f1 in
let l, f2 = track_values state names l cond f2 in
l, t_label_copy f (t_and_simp f1 f2)
| Tif (fc, f1, f2) ->
let _, f1 = track_values state names lesson cond f1 in
let _, f2 = track_values state names lesson cond f2 in
lesson, t_label_copy f (t_if_simp fc f1 f2)
| Tcase (t1, bl) ->
let p1 = point_of_term state names t1 in
let value = match_point state (of_option t1.t_ty) p1 in
let is_pat_var = function
| { pat_node = Pvar _ } -> true | _ -> false in
let branch l br =
let pat, f1, cb = t_open_branch_cb br in
let learn, cond = match bl, pat.pat_node with
| [_], _ -> true, cond (* one branch, can learn *)
| _, Papp (cs, pl) when List.for_all is_pat_var pl ->
(try true, Mint.add_new Exit p1 cs cond (* can learn *)
with Exit -> false, cond) (* contradiction, cannot learn *)
| _, _ -> false, cond (* complex pattern, will not learn *)
in
let names = open_pattern state names value p1 pat in
let m, f1 = track_values state names lesson cond f1 in
let l = if learn then m else l in
l, cb pat f1
in
let l, bl = Util.map_fold_left branch lesson bl in
l, t_label_copy f (t_case t1 bl)
| Tlet (t1, bf) ->
let p1 = point_of_term state names t1 in
let v, f1, cb = t_open_bound_cb bf in
let names = Mvs.add v p1 names in
let l, f1 = track_values state names lesson cond f1 in
l, t_label_copy f (t_let_simp t1 (cb v f1))
| Tquant (Tforall, qf) ->
let vl, trl, f1, cb = t_open_quant_cb qf in
let add_vs s vs = Mvs.add vs (next_point state) s in
let names = List.fold_left add_vs names vl in
let l, f1 = track_values state names lesson cond f1 in
l, t_label_copy f (t_forall_simp (cb vl trl f1))
| Tbinop ((Tor|Tiff),_,_) | Tquant (Texists,_)
| Tapp _ | Tnot _ | Ttrue | Tfalse -> lesson, f
| Tvar _ | Tconst _ | Teps _ -> assert false
let track_values lkm km f =
let state = empty_state lkm km in
let _, f = track_values state Mvs.empty Mint.empty Mint.empty f in
f
(** Weakest preconditions *)
......@@ -815,7 +892,7 @@ let rec unabsurd f = match f.t_node with
| _ ->
t_map unabsurd f
let add_wp_decl name f uc =
let add_wp_decl km name f uc =
(* prepare a proposition symbol *)
let s = "WP_parameter " ^ name.id_string in
let lab = Ident.create_label ("expl:parameter " ^ name.id_string) in
......@@ -827,9 +904,12 @@ let add_wp_decl name f uc =
(* let f = bool_to_prop uc f in *)
let f = unabsurd f in
(* get a known map with tuples added *)
let km = Theory.get_known uc in
let lkm = Theory.get_known uc in
(* remove redundant invariants *)
let f = if Debug.test_flag no_track then f else track_values lkm km f in
(* simplify f *)
let f = Eval_match.eval_match ~inline:Eval_match.inline_nonrec_linear km f in
let f = if Debug.test_flag no_eval then f else
Eval_match.eval_match ~inline:Eval_match.inline_nonrec_linear lkm f in
(* printf "wp: f=%a@." print_term f; *)
let d = create_prop_decl Pgoal pr f in
Theory.add_decl uc d
......@@ -855,7 +935,7 @@ let wp_let env km th { let_sym = lv; let_expr = e } =
let id = match lv with
| LetV pv -> pv.pv_vs.vs_name
| LetA ps -> ps.ps_name in
add_wp_decl id f th
add_wp_decl km id f th
let wp_rec env km th rdl =
let env = mk_env env km th in
......@@ -864,7 +944,7 @@ let wp_rec env km th rdl =
Debug.dprintf debug "wp %s = %a@\n----------------@."
d.fun_ps.ps_name.id_string Pretty.print_term f;
let f = wp_forall (Mvs.keys f.t_vars) f in
add_wp_decl d.fun_ps.ps_name f th
add_wp_decl km d.fun_ps.ps_name f th
in
List.fold_left2 add_one th rdl.rec_defn fl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment