Commit 4f2fd028 authored by Andrei Paskevich's avatar Andrei Paskevich

whyml: put type invariants in specifications

parent ce6e3e43
......@@ -69,6 +69,7 @@ end)
module Sid = Id.S
module Mid = Id.M
module Hid = Id.H
module Wid = Id.W
type preid = ident
......
......@@ -49,6 +49,7 @@ type ident = private {
module Mid : Map.S with type key = ident
module Sid : Mid.Set
module Hid : Hashtbl.S with type key = ident
module Wid : Hashweak.S with type key = ident
val id_equal : ident -> ident -> bool
......
......@@ -366,12 +366,16 @@ let variant_vars varl vsset =
let add_variant s (t,_) = Mvs.set_union s t.t_vars in
List.fold_left add_variant vsset varl
let spec_varmap varm spec =
let spec_vsset spec =
let vsset = pre_vars spec.c_pre Mvs.empty in
let vsset = post_vars spec.c_post vsset in
let vsset = xpost_vars spec.c_xpost vsset in
let vsset = variant_vars spec.c_variant vsset in
add_t_vars vsset varm
vsset
let spec_varmap varm spec = add_t_vars (spec_vsset spec) varm
let spec_vsset spec = Mvs.map (const ()) (spec_vsset spec)
let rec vta_varmap vta =
let varm = match vta.vta_result with
......
......@@ -120,6 +120,8 @@ val create_psymbol : preid -> vty_arrow -> psymbol
val create_psymbol_extra : preid -> vty_arrow -> Spv.t -> Sps.t -> psymbol
val spec_vsset : spec -> Svs.t
(** program expressions *)
type assertion_kind = Aassert | Aassume | Acheck
......
......@@ -827,14 +827,14 @@ let create_pvsymbol id vtv = {
pv_vars = vtv_vars vtv;
}
let create_pvsymbol, restore_pv =
let vs_to_pv = Wvs.create 17 in
let create_pvsymbol, restore_pv, restore_pv_by_id =
let id_to_pv = Wid.create 17 in
(fun id vtv ->
let pv = create_pvsymbol id vtv in
Wvs.set vs_to_pv pv.pv_vs pv;
Wid.set id_to_pv pv.pv_vs.vs_name pv;
pv),
(fun vs -> try Wvs.find vs_to_pv vs with Not_found ->
Loc.error ?loc:vs.vs_name.id_loc (Decl.UnboundVar vs))
(fun vs -> Wid.find id_to_pv vs.vs_name),
(fun id -> Wid.find id_to_pv id)
(** program types *)
......
......@@ -285,7 +285,10 @@ val pv_equal : pvsymbol -> pvsymbol -> bool
val create_pvsymbol : preid -> vty_value -> pvsymbol
val restore_pv : vsymbol -> pvsymbol
(* raises Decl.UnboundVar if the argument is not a pv_vs *)
(* raises Not_found if the argument is not a pv_vs *)
val restore_pv_by_id : ident -> pvsymbol
(* raises Not_found if the argument is not a pv_vs.vs_name *)
(** program types *)
......
......@@ -659,6 +659,79 @@ let create_lenv uc = {
log_denv = Typing.denv_empty_with_globals (find_global_vs uc);
}
(* invariant handling *)
let env_invariant lenv svs =
let kn = get_known lenv.mod_uc in
let lkn = Theory.get_known (get_theory lenv.mod_uc) in
let add_vs vs inv =
let ity = (restore_pv vs).pv_vtv.vtv_ity in
t_and_simp inv (Mlw_wp.full_invariant lkn kn vs ity) in
Svs.fold add_vs svs t_true
let post_invariant lenv inv ity q =
let vs, q = open_post q in
let kn = get_known lenv.mod_uc in
let lkn = Theory.get_known (get_theory lenv.mod_uc) in
let res_inv = Mlw_wp.full_invariant lkn kn vs ity in
let q = t_and_asym_simp q (t_and_simp res_inv inv) in
Mlw_ty.create_post vs q
let spec_invariant lenv svs ity spec =
let inv = env_invariant lenv svs in
let post_inv = post_invariant lenv inv in
let xpost_inv xs q = post_inv xs.xs_ity q in
{ spec with c_pre = t_and_simp spec.c_pre inv;
c_post = post_inv ity spec.c_post;
c_xpost = Mexn.mapi xpost_inv spec.c_xpost }
let ity_or_unit = function
| VTvalue v -> v.vtv_ity
| VTarrow _ -> ity_unit
let expr_vsset svs e =
let add_id id _ s =
try Svs.add (restore_pv_by_id id).pv_vs s
with Not_found -> s in
Mid.fold add_id e.e_varm svs
let abst_invariant lenv e q xq =
let spec = {
c_pre = t_true;
c_effect = eff_empty;
c_post = q;
c_xpost = xq;
c_variant = [];
c_letrec = 0 } in
let ity = ity_or_unit e.e_vty in
let svs = expr_vsset (spec_vsset spec) e in
let spec = spec_invariant lenv svs ity spec in
spec.c_post, spec.c_xpost
let spec_of_lambda lam = {
c_pre = lam.l_pre;
c_effect = lam.l_expr.e_effect;
c_post = lam.l_post;
c_xpost = lam.l_xpost;
c_variant = lam.l_variant;
c_letrec = 0 }
let lambda_invariant lenv svs lam =
let spec = spec_of_lambda lam in
let add_pv s pv = Svs.add pv.pv_vs s in
let svs = List.fold_left add_pv svs lam.l_args in
let ity = ity_or_unit lam.l_expr.e_vty in
let spec = spec_invariant lenv svs ity spec in
{ lam with l_pre = spec.c_pre;
l_post = spec.c_post;
l_xpost = spec.c_xpost }
let lambda_vsset lam =
let del_pv svs pv = Svs.remove pv.pv_vs svs in
let svs = spec_vsset (spec_of_lambda lam) in
let svs = expr_vsset svs lam.l_expr in
List.fold_left del_pv svs lam.l_args
let rec dty_of_ty ty = match ty.ty_node with
| Ty.Tyapp (ts, tyl) -> Denv.tyapp ts (List.map dty_of_ty tyl)
| Ty.Tyvar v -> Denv.tyuvar v
......@@ -842,7 +915,11 @@ let rec type_c lenv gh svs vars dtyc =
let add vs f = let t = t_var vs in t_and_simp (t_equ t t) f in
let xq = Mlw_ty.create_post res (Svs.fold add esvs t_true) in
Mexn.add_new exn xs_exit xq spec.c_xpost in
{ spec with c_xpost = xpost }, vty
let spec = { spec with c_xpost = xpost } in
(* add the invariants *)
let ity = ity_or_unit vty in
let svs = Svs.union svs (spec_vsset spec) in
spec_invariant lenv svs ity spec, vty
and type_v lenv gh svs vars = function
| DSpecV (ghost,v) ->
......@@ -971,7 +1048,9 @@ and expr_desc lenv loc de = match de.de_desc with
| DEabstract (de1, q, xq) ->
let e1 = expr lenv de1 in
let q = create_post lenv "result" e1.e_vty q in
e_abstract e1 q (complete_xpost lenv e1.e_effect xq)
let xq = complete_xpost lenv e1.e_effect xq in
let q, xq = abst_invariant lenv e1 q xq in
e_abstract e1 q xq
| DEassert (ak, f) ->
let ak = match ak with
| Ptree.Aassert -> Aassert
......@@ -1033,10 +1112,15 @@ and expr_rec lenv rdl =
add_local id.id (LetA ps) lenv, (ps, gh, lam) in
let lenv, rdl = Util.map_fold_left step1 lenv rdl in
let step2 (ps, gh, lam) = ps, expr_lam lenv gh lam in
create_rec_defn (List.map step2 rdl)
let rdl = List.map step2 rdl in
let add_rd_vsset s (_, lam) = Svs.union s (lambda_vsset lam) in
let svs = List.fold_left add_rd_vsset Svs.empty rdl in
let step3 (ps, lam) = ps, lambda_invariant lenv svs lam in
create_rec_defn (List.map step3 rdl)
and expr_fun lenv x gh lam =
let lam = expr_lam lenv gh lam in
let lam = lambda_invariant lenv (lambda_vsset lam) lam in
let def = create_fun_defn (Denv.create_user_id x) lam in
def, (List.hd def.rec_defn).fun_ps
......
......@@ -80,7 +80,7 @@ let to_term t = if t.t_ty = None then mk_t_if t else t
(* any vs in post/xpost is either a pvsymbol or a fresh mark *)
let vtv_of_vs vs =
try (restore_pv vs).pv_vtv with UnboundVar _ -> vtv_mark
try (restore_pv vs).pv_vtv with Not_found -> vtv_mark
(* replace every occurrence of [old(t)] with [at(t,'old)] *)
let rec remove_old f = match f.t_node with
......@@ -128,6 +128,7 @@ let expl_post = Ident.create_label "expl:normal postcondition"
let expl_xpost = Ident.create_label "expl:exceptional postcondition"
let expl_assert = Ident.create_label "expl:assertion"
let expl_check = Ident.create_label "expl:check"
let expl_inv = Ident.create_label "expl:type invariant"
let expl_variant = Ident.create_label "expl:variant decreases"
let expl_loop_init = Ident.create_label "expl:loop invariant init"
let expl_loop_keep = Ident.create_label "expl:loop invariant preservation"
......@@ -256,11 +257,11 @@ let decrease ?loc env olds varl =
(** Reconstruct pure values after writes *)
let find_constructors env sts ity = match ity.ity_node with
let find_constructors lkn kn sts ity = match ity.ity_node with
| Itypur (ts,_) ->
let base = ity_pur ts (List.map ity_var ts.ts_args) in
let sbs = ity_match ity_subst_empty base ity in
let csl = Decl.find_constructors env.pure_known ts in
let csl = Decl.find_constructors lkn ts in
if csl = [] || Sts.mem ts sts then Loc.errorm
"Cannot update values of type %a" Mlw_pretty.print_ity base;
let subst ty = ity_full_inst sbs (ity_of_ty ty), None in
......@@ -269,7 +270,7 @@ let find_constructors env sts ity = match ity.ity_node with
| Ityapp (its,_,_) ->
let base = ity_app its (List.map ity_var its.its_args) its.its_regs in
let sbs = ity_match ity_subst_empty base ity in
let csl = Mlw_decl.find_constructors env.prog_known its in
let csl = Mlw_decl.find_constructors kn its in
if csl = [] || Sts.mem its.its_pure sts then Loc.errorm
"Cannot update values of type %a" Mlw_pretty.print_ity base;
let subst vtv =
......@@ -279,6 +280,17 @@ let find_constructors env sts ity = match ity.ity_node with
Sts.add its.its_pure sts, List.map cnstr csl
| Ityvar _ -> assert false
let analyze_var fn_down fn_join lkn kn sts vs ity =
let sts, csl = find_constructors lkn kn sts ity in
let branch (cs,ityl) =
let mk_var (ity,_) = create_vsymbol (id_fresh "y") (ty_of_ity ity) in
let vars = List.map mk_var ityl in
let mk_arg vs (ity, mut) = fn_down sts vs ity mut in
let t = fn_join cs (List.map2 mk_arg vars ityl) vs.vs_ty in
let pat = pat_app cs (List.map pat_var vars) vs.vs_ty in
t_close_branch pat t in
t_case (t_var vs) (List.map branch csl)
let update_var env mreg vs =
let rec update sts vs ity mut =
(* are we a mutable variable? *)
......@@ -286,18 +298,8 @@ let update_var env mreg vs =
let vs = Util.option_apply vs get_vs mut in
(* should we update our value further? *)
let check_reg r _ = reg_occurs r ity.ity_vars in
if ity_pure ity || not (Mreg.exists check_reg mreg) then
t_var vs
else
let sts, csl = find_constructors env sts ity in
let branch (cs,ityl) =
let mk_var (ity,_) = create_vsymbol (id_fresh "y") (ty_of_ity ity) in
let vars = List.map mk_var ityl in
let pat = pat_app cs (List.map pat_var vars) vs.vs_ty in
let mk_arg vs (ity, mut) = update sts vs ity mut in
let t = fs_app cs (List.map2 mk_arg vars ityl) vs.vs_ty in
t_close_branch pat t in
t_case (t_var vs) (List.map branch csl)
if ity_pure ity || not (Mreg.exists check_reg mreg) then t_var vs
else analyze_var update fs_app env.pure_known env.prog_known sts vs ity
in
let vtv = vtv_of_vs vs in
update Sts.empty vs vtv.vtv_ity vtv.vtv_mut
......@@ -366,6 +368,44 @@ let quantify env regs f =
let f = Mvs.fold update vars (subst_at_now true vv' f) in
wp_forall (List.rev (Mreg.values mreg)) f
(* invariants *)
let get_invariant kn v =
let ts = match v.vs_ty.ty_node with
| Tyapp (ts,_) -> ts
| _ -> assert false in
let rec find_td = function
| (its,_,inv) :: _ when ts_equal ts its.its_pure -> inv
| _ :: tdl -> find_td tdl
| [] -> assert false in
let pd = Mid.find ts.ts_name kn in
let inv = match pd.Mlw_decl.pd_node with
| Mlw_decl.PDdata tdl -> find_td tdl
| _ -> assert false in
let sbs = Ty.ty_match Mtv.empty (t_type inv) v.vs_ty in
let u, p = open_post (t_ty_subst sbs Mvs.empty inv) in
wp_expl expl_inv (t_subst_single u (t_var v) p)
let ps_inv = Term.create_psymbol (id_fresh "inv")
[ty_var (create_tvsymbol (id_fresh "a"))]
let full_invariant lkn kn vs ity =
let rec update sts vs ity _ =
if not (ity_inv ity) then t_true else
(* what is our current invariant? *)
let f = match ity.ity_node with
| Ityapp (its,_,_) when its.its_inv ->
get_invariant kn vs
(* ps_app ps_inv [t_var vs] *)
| _ -> t_true in
(* what are our sub-invariants? *)
let join _ fl _ = wp_ands ~sym:true fl in
let g = analyze_var update join lkn kn sts vs ity in
(* put everything together *)
wp_and ~sym:true f g
in
update Sts.empty vs ity None
(** Weakest preconditions *)
let rec wp_expr env e q xq =
......
......@@ -40,6 +40,9 @@ val e_now : expr
val remove_old : Term.term -> Term.term
val full_invariant :
Decl.known_map -> Mlw_decl.known_map -> Term.vsymbol -> ity -> Term.term
(** Weakest preconditions *)
val wp_val: Env.env -> known_map -> theory_uc -> let_sym -> theory_uc
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment