Commit f924e69e authored by Andrei Paskevich's avatar Andrei Paskevich

whyml: keep varmap in psymbols (needed for spec filtering)

parent a3bb9642
......@@ -40,7 +40,7 @@ module M
{ result = old !a + old !b }
let test_f5 () =
{ !x >= 0 } let f = f5 x in let b = ref 0 in f b { result = old !x }
{ !x >= 0 } let b = ref 0 in let f = f5 x in f b { result = old !x }
end
......
......@@ -32,6 +32,7 @@ open Mlw_ty.T
type psymbol = {
ps_name : ident;
ps_vta : vty_arrow;
ps_varm : varmap;
ps_vars : varset;
ps_subst : ity_subst;
}
......@@ -39,10 +40,11 @@ type psymbol = {
let ps_equal : psymbol -> psymbol -> bool = (==)
let create_psymbol_real ~poly id vta varm =
let vars = if poly then vars_empty else vta.vta_vars in
let vars = if poly then vars_empty else vta_vars vta in
let vars = vars_merge varm vars in
{ ps_name = id_register id;
ps_vta = vta_filter varm vta;
ps_varm = varm;
ps_vars = vars;
ps_subst = vars_freeze vars; }
......@@ -97,7 +99,6 @@ type let_sym =
type val_decl = {
val_sym : let_sym;
val_vty : vty;
val_vars : varmap;
}
type variant = {
......@@ -171,12 +172,12 @@ let create_val id vty = match vty with
| VTvalue v ->
let pv = create_pvsymbol id v in
vty_check vars_empty vty;
{ val_sym = LetV pv; val_vty = vty; val_vars = Mid.empty }
{ val_sym = LetV pv; val_vty = vty }
| VTarrow a ->
let varm = vta_varmap a in
let ps = create_psymbol_poly id a varm in
vta_check ps.ps_vars a;
{ val_sym = LetA ps; val_vty = vty; val_vars = varm }
{ val_sym = LetA ps; val_vty = vty }
(** patterns *)
......@@ -394,7 +395,6 @@ and let_defn = {
and rec_defn = {
rec_ps : psymbol;
rec_lambda : lambda;
rec_vars : varmap;
}
and lambda = {
......@@ -450,7 +450,7 @@ let vta_of_expr e = match e.e_vty with
let e_arrow ps vta =
let sbs = vta_vars_match ps.ps_subst ps.ps_vta vta in
let vars = Mid.singleton ps.ps_name ps.ps_vars in
let vars = Mid.add ps.ps_name ps.ps_vars ps.ps_varm in
let vta = vta_full_inst sbs ps.ps_vta in
mk_expr (Earrow ps) (VTarrow vta) eff_empty vars
......@@ -511,26 +511,27 @@ let e_app_real e pv =
let eff = eff_union e.e_effect spec.c_effect in
mk_expr (Eapp (e,pv,spec)) vty eff (add_pv_vars pv e.e_vars)
let create_fun_defn id lam =
let create_fun_defn id lam recsyms =
let e = lam.l_expr in
let spec = {
c_pre = lam.l_pre;
c_post = lam.l_post;
c_xpost = lam.l_xpost;
c_effect = e.e_effect; } in
let varm = spec_varmap e.e_vars lam.l_variant spec in
let varm = Mid.set_diff e.e_vars recsyms in
let varm = spec_varmap varm lam.l_variant spec in
let del_pv m pv = Mid.remove pv.pv_vs.vs_name m in
let varm = List.fold_left del_pv varm lam.l_args in
let vta = vty_arrow lam.l_args ~spec e.e_vty in
{ rec_ps = create_psymbol_poly id vta varm;
rec_lambda = lam;
rec_vars = varm; }
rec_lambda = lam; }
(* FIXME: if the given rdl is not the result of create_rec_defn,
the varmap calculation below might be off. We should probably
make [rec_defn list] a private type. *)
let e_rec rdl e =
let add_vars m rd = varmap_union m rd.rec_vars in
let remove_ps m rd = Mid.remove rd.rec_ps.ps_name m in
let add_vars m rd = varmap_union m rd.rec_ps.ps_varm in
let vars = List.fold_left add_vars e.e_vars rdl in
let vars = List.fold_left remove_ps vars rdl in
mk_expr (Erec (rdl,e)) e.e_vty e.e_effect vars
let on_value fn e = match e.e_node with
......@@ -847,10 +848,6 @@ let e_absurd ity =
(* Compute the fixpoint on recursive definitions *)
let vars_equal vs1 vs2 =
Stv.equal vs1.vars_tv vs2.vars_tv &&
Sreg.equal vs1.vars_reg vs2.vars_reg
let eff_equal eff1 eff2 =
Sreg.equal eff1.eff_reads eff2.eff_reads &&
Sreg.equal eff1.eff_writes eff2.eff_writes &&
......@@ -877,7 +874,7 @@ let rec vta_compat a1 a2 =
let ps_compat ps1 ps2 =
vta_compat ps1.ps_vta ps2.ps_vta &&
vars_equal ps1.ps_vars ps2.ps_vars
Mid.equal (fun _ _ -> true) ps1.ps_varm ps2.ps_varm
let rec expr_subst psm e = match e.e_node with
| Earrow ps when Mid.mem ps.ps_name psm ->
......@@ -922,8 +919,10 @@ let rec expr_subst psm e = match e.e_node with
| Elogic _ | Evalue _ | Earrow _ | Eany _ | Eabsurd | Eassert _ -> e
and create_rec_defn defl =
let add_sym acc (ps,_) = Sid.add ps.ps_name acc in
let recsyms = List.fold_left add_sym Sid.empty defl in
let conv m (ps,lam) =
let rd = create_fun_defn (id_clone ps.ps_name) lam in
let rd = create_fun_defn (id_clone ps.ps_name) lam recsyms in
if ps_compat ps rd.rec_ps then m, { rd with rec_ps = ps }
else Mid.add ps.ps_name rd.rec_ps m, rd in
let m, rdl = Util.map_fold_left conv Mid.empty defl in
......@@ -945,11 +944,13 @@ and subst_rd psm rdl =
is passed to create_rec_defn above which repeats substitution
until the effects are stabilized. TODO: prove correctness *)
let create_rec_defn defl =
let add_sym acc (ps,_) = Sid.add ps.ps_name acc in
let recsyms = List.fold_left add_sym Sid.empty defl in
let conv m (ps,lam) = match lam.l_expr.e_vty with
| VTarrow _ -> Loc.errorm ?loc:lam.l_expr.e_loc
"The body of a recursive function must be a first-order value"
| VTvalue _ ->
let rd = create_fun_defn (id_clone ps.ps_name) lam in
let rd = create_fun_defn (id_clone ps.ps_name) lam recsyms in
Mid.add ps.ps_name rd.rec_ps m, rd in
let m, rdl = Util.map_fold_left conv Mid.empty defl in
subst_rd m rdl
......@@ -957,7 +958,7 @@ let create_rec_defn defl =
let create_fun_defn id lam =
if lam.l_variant <> [] then
Loc.errorm "Variants are not allowed in a non-recursive definition";
create_fun_defn id lam
create_fun_defn id lam Sid.empty
(* fold *)
......
......@@ -36,6 +36,7 @@ open Mlw_ty.T
type psymbol = private {
ps_name : ident;
ps_vta : vty_arrow;
ps_varm : varmap;
ps_vars : varset;
(* this varset covers the type variables and regions of the defining
lambda that cannot be instantiated. Every other type variable
......@@ -84,7 +85,6 @@ type let_sym =
type val_decl = private {
val_sym : let_sym;
val_vty : vty;
val_vars : varmap;
}
val create_val : Ident.preid -> vty -> val_decl
......@@ -161,7 +161,6 @@ and let_defn = private {
and rec_defn = private {
rec_ps : psymbol;
rec_lambda : lambda;
rec_vars : varmap;
}
and lambda = {
......
......@@ -146,10 +146,10 @@ let print_psty fmt ps =
fprintf fmt "[%a]@ " (print_list comma print_tv) (Stv.elements tvs) in
let print_regs fmt regs = if not (Sreg.is_empty regs) then
fprintf fmt "<%a>@ " (print_list comma print_regty) (Sreg.elements regs) in
let vars = ps.ps_vta.vta_vars in
let vars = vta_vars ps.ps_vta in
fprintf fmt "@[%a :@ %a%a%a@]"
print_ps ps
print_tvs (Stv.diff vars.vars_tv ps.ps_vars.vars_tv)
print_tvs (Mtv.set_diff vars.vars_tv ps.ps_subst.ity_subst_tv)
print_regs (Mreg.set_diff vars.vars_reg ps.ps_subst.ity_subst_reg)
print_vta ps.ps_vta
......
......@@ -764,15 +764,15 @@ and vty_arrow = {
vta_result : vty;
vta_spec : spec;
vta_ghost : bool;
vta_vars : varset;
(* this varset covers every type variable and region in vta_arg
and vta_result, but may skip some type variables and regions
in vta_effect *)
}
let vty_vars = function
let rec vta_vars vta =
let add_arg vars pv = vars_union vars pv.pv_vtv.vtv_vars in
List.fold_left add_arg (vty_vars vta.vta_result) vta.vta_args
and vty_vars = function
| VTvalue vtv -> vtv.vtv_vars
| VTarrow vta -> vta.vta_vars
| VTarrow vta -> vta_vars vta
let vty_ghost = function
| VTvalue vtv -> vtv.vtv_ghost
......@@ -788,14 +788,12 @@ let ty_of_vty = function
let spec_check spec vty = spec_check spec (ty_of_vty vty)
let vty_arrow_unsafe argl ~spec ~ghost vty =
let add_arg vars { pv_vtv = vtv } = vars_union vars vtv.vtv_vars in
{ vta_args = argl;
vta_result = vty;
vta_spec = spec;
vta_ghost = ghost || vty_ghost vty;
vta_vars = List.fold_left add_arg (vty_vars vty) argl;
}
let vty_arrow_unsafe argl ~spec ~ghost vty = {
vta_args = argl;
vta_result = vty;
vta_spec = spec;
vta_ghost = ghost || vty_ghost vty;
}
let vty_arrow argl ?spec ?(ghost=false) vty =
(* we accept a mutable vty_value as a result to simplify Mlw_expr,
......
......@@ -286,9 +286,6 @@ and vty_arrow = private {
vta_result : vty;
vta_spec : spec;
vta_ghost : bool;
vta_vars : varset;
(* this varset covers every type variable and region in vta_arg
and vta_result, but not necessarily in vta_spec *)
}
exception UnboundException of xsymbol
......@@ -297,17 +294,17 @@ exception UnboundException of xsymbol
val vty_arrow : pvsymbol list -> ?spec:spec -> ?ghost:bool -> vty -> vty_arrow
(* this only compares the types of arguments and results, and ignores
the spec. In other words, only the type variables and regions
in .vta_vars are matched. The caller should supply a "freezing"
the spec. In other words, only the type variables and regions in
[vta_vars vta] are matched. The caller should supply a "freezing"
substitution that covers all external type variables and regions. *)
val vta_vars_match : ity_subst -> vty_arrow -> vty_arrow -> ity_subst
(* the substitution must cover not only vta_vars but also every
type variable and every region in vta_spec *)
(* the substitution must cover not only [vta_vars vta] but
also every type variable and every region in vta_spec *)
val vta_full_inst : ity_subst -> vty_arrow -> vty_arrow
(* remove from the given arrow every effect that is covered
neither by the arrow's vta_vars nor by the given varmap *)
neither by the arrow's arguments nor by the given varmap *)
val vta_filter : varmap -> vty_arrow -> vty_arrow
(* apply a function specification to a variable argument *)
......@@ -324,4 +321,7 @@ val spec_check : spec -> vty -> unit
val ity_of_vty : vty -> ity
val ty_of_vty : vty -> ty
(* collects the type variables and regions in arguments and values,
but ignores the spec *)
val vta_vars : vty_arrow -> varset
val vty_vars : vty -> varset
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment