Commit 0373bd66 authored by Andrei Paskevich's avatar Andrei Paskevich

whyml: type assertions and strong updates

parent 4ffbe564
......@@ -37,6 +37,7 @@ end)
module Svs = Vsym.S
module Mvs = Vsym.M
module Hvs = Vsym.H
module Wvs = Vsym.W
let vs_equal : vsymbol -> vsymbol -> bool = (==)
......
......@@ -34,6 +34,7 @@ type vsymbol = private {
module Mvs : Map.S with type key = vsymbol
module Svs : Mvs.Set
module Hvs : Hashtbl.S with type key = vsymbol
module Wvs : Hashweak.S with type key = vsymbol
val vs_equal : vsymbol -> vsymbol -> bool
val vs_hash : vsymbol -> int
......
......@@ -148,10 +148,10 @@ let rec occur_check_reg tv = function
| Dts (_,dl) ->
List.iter (occur_check_reg tv) dl
let rec unify d1 d2 = match d1,d2 with
let rec unify ~weak d1 d2 = match d1,d2 with
| Dvar { contents = Dval d1 }, d2
| d1, Dvar { contents = Dval d2 } ->
unify d1 d2
unify ~weak d1 d2
| Dvar { contents = Dtvs tv1 },
Dvar { contents = Dtvs tv2 } when tv_equal tv1 tv2 ->
()
......@@ -162,11 +162,11 @@ let rec unify d1 d2 = match d1,d2 with
| Dits (its1, dl1, rl1), Dits (its2, dl2, rl2) when its_equal its1 its2 ->
assert (List.length rl1 = List.length rl2);
assert (List.length dl1 = List.length dl2);
List.iter2 unify dl1 dl2;
List.iter2 unify_reg rl1 rl2
List.iter2 (unify ~weak) dl1 dl2;
if not weak then List.iter2 unify_reg rl1 rl2
| Dts (ts1, dl1), Dts (ts2, dl2) when ts_equal ts1 ts2 ->
assert (List.length dl1 = List.length dl2);
List.iter2 unify dl1 dl2
List.iter2 (unify ~weak) dl1 dl2
| _ -> raise Exit
and unify_reg r1 r2 =
......@@ -186,14 +186,18 @@ and unify_reg r1 r2 =
| d, Rvar ({ contents = Rtvs (tv,rd,_) } as r) ->
let dity = dity_of_reg d in
occur_check_reg tv dity;
unify rd dity;
unify ~weak:false rd dity;
r := Rval d
| Rureg (tv1,_,_), Rureg (tv2,_,_) when tv_equal tv1 tv2 -> ()
| Rreg (reg1,_), Rreg (reg2,_) when reg_equal reg1 reg2 -> ()
| _ -> raise Exit
let unify_weak d1 d2 =
try unify ~weak:true d1 d2
with Exit -> raise (TypeMismatch (ity_of_dity d1, ity_of_dity d2))
let unify d1 d2 =
try unify d1 d2
try unify ~weak:false d1 d2
with Exit -> raise (TypeMismatch (ity_of_dity d1, ity_of_dity d2))
let ts_arrow =
......
......@@ -46,6 +46,9 @@ val make_arrow_type: dity list -> dity -> dity
val unify: dity -> dity -> unit
(** destructive unification *)
val unify_weak: dity -> dity -> unit
(** destructive unification, ignores regions *)
val ity_of_dity: dity -> ity
val vty_of_dity: dity -> vty
(** use with care, only once unification is done *)
......
......@@ -40,6 +40,15 @@ let create_pvsymbol id vtv = {
pv_vtv = vtv;
}
let create_pvsymbol, restore_pv =
let vs_to_pv = Wvs.create 17 in
(fun id vtv ->
let pv = create_pvsymbol id vtv in
Wvs.set vs_to_pv pv.pv_vs pv;
pv),
(fun vs -> try Wvs.find vs_to_pv vs with
Not_found -> raise (Decl.UnboundVar vs))
(** program symbols *)
type psymbol = {
......@@ -362,9 +371,12 @@ let mk_expr node vty eff vars = {
}
let varmap_join = Mid.fold (fun _ -> vars_union)
let varmap_union = Mid.union (fun _ s1 s2 -> Some (vars_union s1 s2))
let varmap_union = Mid.set_union
let add_pv_vars pv m = Mid.add pv.pv_vs.vs_name pv.pv_vtv.vtv_vars m
let add_vs_vars vs m = add_pv_vars (restore_pv vs) m
let add_t_vars t m = Mvs.fold (fun vs _ m -> add_vs_vars vs m) t.t_vars m
let add_e_vars e m = varmap_union e.e_vars m
let e_value pv =
......@@ -481,10 +493,9 @@ let create_fun_defn id lam =
let vsset = Mexn.fold (fun _ -> add_term) lam.l_xpost vsset in
let vsset =
List.fold_right (fun v -> add_term v.v_term) lam.l_variant vsset in
let add_vs vs _ m = Mid.add vs.vs_name (vs_vars vars_empty vs) m in
let add_vs vs _ m = add_vs_vars vs m in
let del_pv m pv = Mid.remove pv.pv_vs.vs_name m in
let recvars = Mvs.fold add_vs vsset Mid.empty in
let recvars = add_e_vars lam.l_expr recvars in
let recvars = Mvs.fold add_vs vsset lam.l_expr.e_vars in
let recvars = List.fold_left del_pv recvars lam.l_args in
let vars = varmap_join recvars vars_empty in
(* compute rec_ps.ps_vta *)
......@@ -793,9 +804,9 @@ let e_absurd ity =
mk_expr Eabsurd vty eff_empty Mid.empty
let e_assert ak f =
let eff, vars = assert false (*TODO*) in
let vars = add_t_vars f Mid.empty in
let vty = VTvalue (vty_value ity_unit) in
mk_expr (Eassert (ak, f)) vty eff vars
mk_expr (Eassert (ak, f)) vty eff_empty vars
(* Compute the fixpoint on recursive definitions *)
......
......@@ -160,6 +160,11 @@ let print_ppat fmt ppat = print_pat fmt ppat.ppat_pattern
(* expressions *)
let print_ak fmt = function
| Aassert -> fprintf fmt "assert"
| Aassume -> fprintf fmt "assume"
| Acheck -> fprintf fmt "check"
let print_list_next sep print fmt = function
| [] -> ()
| [x] -> print true fmt x
......@@ -243,6 +248,10 @@ and print_enode pri fmt e = match e.e_node with
| Etry (e,bl) ->
fprintf fmt "try %a with@\n@[<hov>%a@]@\nend"
print_expr e (print_list newline print_xbranch) bl
| Eabsurd ->
fprintf fmt "absurd"
| Eassert (ak,f) ->
fprintf fmt "%a@ (%a)" print_ak ak print_term f
| _ ->
fprintf fmt "<expr TODO>"
......
......@@ -126,8 +126,6 @@ let vars_union s1 s2 = {
vars_reg = Sreg.union s1.vars_reg s2.vars_reg;
}
let vs_vars s vs = { s with vars_tv = ty_freevars s.vars_tv vs.vs_ty }
let create_varset tvs regs = {
vars_tv = Sreg.fold (fun r -> Stv.union r.reg_ity.ity_vars.vars_tv) regs tvs;
vars_reg = regs;
......
......@@ -168,8 +168,6 @@ val vars_union : varset -> varset -> varset
val vars_freeze : varset -> ity_subst
val vs_vars : varset -> vsymbol -> varset
val create_varset : Stv.t -> Sreg.t -> varset
(* exception symbols *)
......
......@@ -136,6 +136,9 @@ let dity_unit = ts_app (ts_tuple 0) []
let expected_type e dity =
unify e.dexpr_type dity
let expected_type_weak e dity =
unify_weak e.dexpr_type dity
let rec extract_labels labs loc e = match e.Ptree.expr_desc with
| Ptree.Enamed (Ptree.Lstr s, e) -> extract_labels (s :: labs) loc e
| Ptree.Enamed (Ptree.Lpos p, e) -> extract_labels labs (Some p) e
......@@ -413,7 +416,7 @@ and dexpr_desc denv loc = function
let e1 = { expr_desc = Eapply (fl,e1); expr_loc = loc } in
let e1 = dexpr denv e1 in
let e2 = dexpr denv e2 in
expected_type e2 e1.dexpr_type;
expected_type_weak e2 e1.dexpr_type;
DEassign (e1, e2), dity_unit
| Ptree.Econstant (ConstInt _ as c) ->
DEconstant c, dity_int
......
......@@ -27,6 +27,10 @@ module N
exception Exit (tree int)
type dref 'a = {| mutable dcontents : ref 'a |}
let create_dref i = {| dcontents = {| contents = i |} |}
let myfun r =
let rec on_tree t = match t with
| Node {| contents = v |} f -> v + on_forest f
......@@ -36,6 +40,11 @@ module N
| Nil -> 1
end
in
let dr = create_dref 0 in
let or = dr.dcontents in
let nr = {| contents = 1 |} in
dr.dcontents <- nr;
assert { r = r };
try on_tree r with Exit -> 0 end
end
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment