Commit c0696892 authored by Andrei Paskevich's avatar Andrei Paskevich

whyml: back to the KISS principle

Merge specifications into program types, as JCF intended.
parent 14c2c552
......@@ -178,7 +178,7 @@ let letvar_news = function
| LetA ps -> check_vars ps.ps_vars; Sid.singleton ps.ps_name
let create_let_decl ld =
let news = letvar_news ld.let_var in
let news = letvar_news ld.let_sym in
(*
let syms = syms_varmap Sid.empty ld.let_expr.e_vars in
let syms = syms_effect syms ld.let_expr.e_effect in
......@@ -207,7 +207,7 @@ let create_rec_decl rdl =
mk_decl (PDrec rdl) (*syms*) news
let create_val_decl vd =
let news = letvar_news vd.val_name in
let news = letvar_news vd.val_sym in
(*
let syms = syms_type_v Sid.empty vd.val_spec in
let syms = syms_varmap syms vd.val_vars in
......
......@@ -202,8 +202,9 @@ let unify d1 d2 = unify ~weak:false d1 d2
type dvty = dity list * dity (* A -> B -> C == ([A;B],C) *)
let vty_of_dvty (argl,res) =
let add a v = VTarrow (vty_arrow (vty_value (ity_of_dity a)) v) in
List.fold_right add argl (VTvalue (vty_value (ity_of_dity res)))
let vtv = VTvalue (vty_value (ity_of_dity res)) in
let conv a = create_pvsymbol (id_fresh "x") (vty_value (ity_of_dity a)) in
if argl = [] then vtv else VTarrow (vty_arrow (List.map conv argl) vtv)
type tvars = dity list
......@@ -284,14 +285,14 @@ let specialize_xsymbol xs =
let specialize_vtarrow vars vta =
let htv = Htv.create 3 and hreg = Hreg.create 3 in
let conv vtv = dity_of_vtv htv hreg vars vtv in
let conv pv = dity_of_vtv htv hreg vars pv.pv_vtv in
let rec specialize a =
let arg = conv a.vta_arg in
let argl,res = match a.vta_result with
| VTvalue v -> [], conv v
let argl = List.map conv a.vta_args in
let narg,res = match a.vta_result with
| VTvalue v -> [], dity_of_vtv htv hreg vars v
| VTarrow a -> specialize a
in
arg::argl, res
argl @ narg, res
in
specialize vta
......
This diff is collapsed.
......@@ -27,26 +27,6 @@ open Term
open Mlw_ty
open Mlw_ty.T
(** program variables *)
(* pvsymbols represent function arguments and pattern variables *)
type pvsymbol = private {
pv_vs : vsymbol;
pv_vtv : vty_value;
}
module Mpv : Map.S with type key = pvsymbol
module Spv : Mpv.Set
module Hpv : Hashtbl.S with type key = pvsymbol
module Wpv : Hashweak.S with type key = pvsymbol
val pv_equal : pvsymbol -> pvsymbol -> bool
val create_pvsymbol : preid -> vty_value -> pvsymbol
val restore_pv : vsymbol -> pvsymbol
(** program symbols *)
(* psymbols represent lambda-abstractions. They are polymorphic and
......@@ -59,17 +39,12 @@ type psymbol = private {
ps_vars : varset;
(* this varset covers the type variables and regions of the defining
lambda that cannot be instantiated. Every other type variable
and region in ps_vty is generalized and can be instantiated. *)
and region in ps_vta is generalized and can be instantiated. *)
ps_subst : ity_subst;
(* this substitution instantiates every type variable and region
in ps_vars to itself *)
}
module Mps : Map.S with type key = psymbol
module Sps : Mps.Set
module Hps : Hashtbl.S with type key = psymbol
module Wps : Hashweak.S with type key = psymbol
val ps_equal : psymbol -> psymbol -> bool
val create_psymbol : preid -> vty_arrow -> psymbol
......@@ -102,39 +77,17 @@ exception HiddenPLS of lsymbol
(** specification *)
type pre = term (* precondition: pre_fmla *)
type post = term (* postcondition: eps result . post_fmla *)
type xpost = post Mexn.t (* exceptional postconditions *)
val create_post : vsymbol -> term -> post
val open_post : post -> vsymbol * term
type type_c = {
c_pre : pre;
c_effect : effect;
c_result : type_v;
c_post : post;
c_xpost : xpost;
}
and type_v =
| SpecV of vty_value
| SpecA of pvsymbol list * type_c
type let_var =
type let_sym =
| LetV of pvsymbol
| LetA of psymbol
type val_decl = private {
val_name : let_var;
val_spec : type_v;
val_vars : varset Mid.t;
val_sym : let_sym;
val_vty : vty;
val_vars : varmap;
}
val create_val : Ident.preid -> type_v -> val_decl
exception DuplicateArg of pvsymbol
exception UnboundException of xsymbol
val create_val : Ident.preid -> vty -> val_decl
(** patterns *)
......@@ -175,24 +128,23 @@ type expr = private {
e_node : expr_node;
e_vty : vty;
e_effect : effect;
e_vars : varset Mid.t;
e_vars : varmap;
e_label : Slab.t;
e_loc : Loc.position option;
e_tag : Hashweak.tag;
}
and expr_node = private
| Elogic of term
| Evalue of pvsymbol
| Earrow of psymbol
| Eapp of expr * pvsymbol
| Eapp of expr * pvsymbol * spec
| Elet of let_defn * expr
| Erec of rec_defn list * expr
| Eif of expr * expr * expr
| Ecase of expr * (ppattern * expr) list
| Eassign of expr * region * pvsymbol
| Eghost of expr
| Eany of type_c
| Eany of spec
| Eloop of invariant * variant list * expr
| Efor of pvsymbol * for_bounds * invariant * expr
| Eraise of xsymbol * expr
......@@ -202,14 +154,14 @@ and expr_node = private
| Eabsurd
and let_defn = private {
let_var : let_var;
let_sym : let_sym;
let_expr : expr;
}
and rec_defn = private {
rec_ps : psymbol;
rec_lambda : lambda;
rec_vars : varset Mid.t;
rec_vars : varmap;
}
and lambda = {
......@@ -226,11 +178,6 @@ and variant = {
v_rel : lsymbol option; (* tau tau : prop *)
}
module Mexpr : Map.S with type key = expr
module Sexpr : Mexpr.Set
module Hexpr : Hashtbl.S with type key = expr
module Wexpr : Hashweak.S with type key = expr
val e_label : ?loc:Loc.position -> Slab.t -> expr -> expr
val e_label_add : label -> expr -> expr
val e_label_copy : expr -> expr -> expr
......@@ -269,7 +216,7 @@ exception Immutable of expr
val e_assign : expr -> expr -> expr
val e_ghost : expr -> expr
val e_any : type_c -> expr
val e_any : spec -> vty -> expr
val e_void : expr
......
......@@ -311,7 +311,7 @@ let add_pdecl uc d =
let defn cl = List.map constructor cl in
let dl = List.map (fun (its,cl) -> its.its_pure, defn cl) dl in
add_to_theory Theory.add_data_decl uc dl
| PDval { val_name = lv } | PDlet { let_var = lv } ->
| PDval { val_sym = lv } | PDlet { let_sym = lv } ->
add_let uc lv
| PDrec rdl ->
List.fold_left add_rec uc rdl
......
......@@ -130,8 +130,9 @@ let print_vtv fmt vtv =
fprintf fmt "%s%a" (if vtv.vtv_ghost then "?" else "") print_ity vtv.vtv_ity
let rec print_vta fmt vta =
fprintf fmt "%a ->@ %a%a" print_vtv vta.vta_arg
print_effect vta.vta_effect print_vty vta.vta_result
let print_arg fmt pv = fprintf fmt "%a ->@ " print_vtv pv.pv_vtv in
fprintf fmt "%a%a%a" (print_list nothing print_arg) vta.vta_args
print_effect vta.vta_spec.c_effect print_vty vta.vta_result
and print_vty fmt = function
| VTarrow vta -> print_vta fmt vta
......@@ -168,18 +169,20 @@ let forget_lv = function
| LetA ps -> forget_ps ps
let rec print_type_v fmt = function
| SpecV vtv -> print_vtv fmt vtv
| SpecA (pvl,tyc) ->
| VTvalue vtv -> print_vtv fmt vtv
| VTarrow vta ->
let print_arg fmt pv = fprintf fmt "@[(%a)@] ->@ " print_pvty pv in
fprintf fmt "%a%a" (print_list nothing print_arg) pvl print_type_c tyc;
List.iter forget_pv pvl
fprintf fmt "%a%a"
(print_list nothing print_arg) vta.vta_args
(print_type_c vta.vta_spec) vta.vta_result;
List.iter forget_pv vta.vta_args
and print_type_c fmt tyc =
and print_type_c spec fmt vty =
fprintf fmt "{ %a }@ %a%a@ { %a }"
print_term tyc.c_pre
print_effect tyc.c_effect
print_type_v tyc.c_result
print_post tyc.c_post
print_term spec.c_pre
print_effect spec.c_effect
print_type_v vty
print_post spec.c_post
(* TODO: print_xpost *)
let print_invariant fmt f =
......@@ -262,14 +265,14 @@ and print_enode pri fmt e = match e.e_node with
print_pv fmt v
| Earrow a ->
print_ps fmt a
| Eapp (e,v) ->
| Eapp (e,v,_) ->
fprintf fmt "(%a@ %a)" (print_lexpr pri) e print_pv v
| Elet ({ let_var = LetV pv ; let_expr = e1 }, e2)
| Elet ({ let_sym = LetV pv ; let_expr = e1 }, e2)
when pv.pv_vs.vs_name.id_string = "_" &&
ity_equal pv.pv_vtv.vtv_ity ity_unit ->
fprintf fmt (protect_on (pri > 0) "%a;@\n%a")
print_expr e1 print_expr e2;
| Elet ({ let_var = lv ; let_expr = e1 }, e2) ->
| Elet ({ let_sym = lv ; let_expr = e1 }, e2) ->
fprintf fmt (protect_on (pri > 0) "@[<hov 2>let %a =@ %a@ in@]@\n%a")
print_lv lv (print_lexpr 4) e1 print_expr e2;
forget_lv lv
......@@ -309,8 +312,8 @@ and print_enode pri fmt e = match e.e_node with
fprintf fmt "abstract %a@ { %a }" print_expr e print_post q
| Eghost e ->
fprintf fmt "ghost@ %a" print_expr e
| Eany tyc ->
fprintf fmt "any@ %a" print_type_c tyc
| Eany spec ->
fprintf fmt "any@ %a" (print_type_c spec) e.e_vty
and print_branch fmt ({ ppat_pattern = p }, e) =
fprintf fmt "@[<hov 4>| %a ->@ %a@]" print_pat p print_expr e;
......@@ -388,12 +391,12 @@ let print_data_decl fst fmt (ts,csl) =
(print_head fst) ts (print_list newline print_constr) csl;
forget_tvs_regs ()
let print_val_decl fmt { val_name = lv ; val_spec = tyv } =
fprintf fmt "@[<hov 2>val (%a) :@ %a@]" print_lv lv print_type_v tyv;
let print_val_decl fmt { val_sym = lv ; val_vty = vty } =
fprintf fmt "@[<hov 2>val (%a) :@ %a@]" print_lv lv print_type_v vty;
(* FIXME: don't forget global regions *)
forget_tvs_regs ()
let print_let_decl fmt { let_var = lv ; let_expr = e } =
let print_let_decl fmt { let_sym = lv ; let_expr = e } =
fprintf fmt "@[<hov 2>let %a =@ %a@]" print_lv lv print_expr e;
(* FIXME: don't forget global regions *)
forget_tvs_regs ()
......@@ -438,6 +441,9 @@ let () = Exn_printer.register
fprintf fmt "Region %a is used twice" print_reg r
| Mlw_ty.UnboundRegion r ->
fprintf fmt "Unbound region %a" print_reg r
| Mlw_ty.UnboundException xs ->
fprintf fmt "This function raises %a but does not \
specify a post-condition for it" print_xs xs
| Mlw_ty.RegionMismatch (r1,r2) ->
fprintf fmt "Region mismatch between %a and %a"
print_regty r1 print_regty r2
......@@ -467,10 +473,5 @@ let () = Exn_printer.register
fprintf fmt "This expression is not a function and cannot be applied"
| Mlw_expr.Immutable _e ->
fprintf fmt "Mutable expression expected"
| Mlw_expr.UnboundException xs ->
fprintf fmt "This function raises %a but does not \
specify a post-condition for it" print_xs xs
| Mlw_expr.DuplicateArg pv ->
fprintf fmt "Argument %a is used twice" print_pv pv
| _ -> raise exn
end
......@@ -53,9 +53,6 @@ val print_ppat : formatter -> ppattern -> unit (* program patterns *)
val print_expr : formatter -> expr -> unit (* expression *)
val print_type_c : formatter -> type_c -> unit
val print_type_v : formatter -> type_v -> unit
val print_ty_decl : formatter -> itysymbol -> unit
val print_data_decl : formatter -> data_decl -> unit
val print_next_data_decl : formatter -> data_decl -> unit
......
......@@ -34,6 +34,8 @@ module rec T : sig
vars_reg : Reg.S.t;
}
type varmap = varset Mid.t
type itysymbol = {
its_pure : tysymbol;
its_args : tvsymbol list;
......@@ -66,6 +68,8 @@ end = struct
vars_reg : Reg.S.t;
}
type varmap = varset Mid.t
type itysymbol = {
its_pure : tysymbol;
its_args : tvsymbol list;
......@@ -126,6 +130,8 @@ let vars_union s1 s2 = {
vars_reg = Sreg.union s1.vars_reg s2.vars_reg;
}
let vars_merge = Mid.fold (fun _ -> vars_union)
let create_varset tvs regs = {
vars_tv = Sreg.fold (fun r -> Stv.union r.reg_ity.ity_vars.vars_tv) regs tvs;
vars_reg = regs;
......@@ -621,7 +627,74 @@ let eff_filter vars e =
eff_resets = Mreg.mapi_filter reset e.eff_resets;
}
(* program types *)
(** specification *)
type pre = term (* precondition: pre_fmla *)
type post = term (* postcondition: eps result . post_fmla *)
type xpost = post Mexn.t (* exceptional postconditions *)
let create_post vs f = t_eps_close vs f
let open_post f = match f.t_node with
| Teps bf -> t_open_bound bf
| _ -> Loc.errorm "invalid post-condition"
let check_post ty f = match f.t_node with
| Teps _ -> Ty.ty_equal_check ty (t_type f)
| _ -> Loc.errorm "invalid post-condition"
type spec = {
c_pre : pre;
c_post : post;
c_xpost : xpost;
c_effect : effect;
}
let spec_empty ty = {
c_pre = t_true;
c_post = create_post (create_vsymbol (id_fresh "dummy") ty) t_true;
c_xpost = Mexn.empty;
c_effect = eff_empty;
}
let spec_full_inst sbs tvm vsm c =
let subst = t_ty_subst tvm vsm in {
c_pre = subst c.c_pre;
c_post = subst c.c_post;
c_xpost = Mexn.map subst c.c_xpost;
c_effect = eff_full_inst sbs c.c_effect;
}
let spec_subst sbs c =
let subst = t_subst sbs in {
c_pre = subst c.c_pre;
c_post = subst c.c_post;
c_xpost = Mexn.map subst c.c_xpost;
c_effect = c.c_effect;
}
let spec_filter varm vars c =
let add _ f s = Mvs.set_union f.t_vars s in
let vss = add () c.c_pre c.c_post.t_vars in
let vss = Mexn.fold add c.c_xpost vss in
let check { vs_name = id } _ = if not (Mid.mem id varm) then
Loc.errorm "Local variable %s escapes from its scope" id.id_string in
Mvs.iter check vss;
{ c with c_effect = eff_filter vars c.c_effect }
exception UnboundException of xsymbol
let spec_check c ty =
if c.c_pre.t_ty <> None then
Loc.error ?loc:c.c_pre.t_loc (Term.FmlaExpected c.c_pre);
check_post ty c.c_post;
Mexn.iter (fun xs q -> check_post (ty_of_ity xs.xs_ity) q) c.c_xpost;
let sexn = Sexn.union c.c_effect.eff_raises c.c_effect.eff_ghostx in
let sexn = Mexn.set_diff sexn c.c_xpost in
if not (Sexn.is_empty sexn) then
raise (UnboundException (Sexn.choose sexn))
(** program variables *)
type vty_value = {
vtv_ity : ity;
......@@ -630,27 +703,6 @@ type vty_value = {
vtv_vars : varset;
}
type vty =
| VTvalue of vty_value
| VTarrow of vty_arrow
and vty_arrow = {
vta_arg : vty_value;
vta_result : vty;
vta_effect : effect;
vta_ghost : bool;
vta_vars : varset;
(* this varset covers every type variable and region in vta_arg
and vta_result, but may skip some type variables and regions
in vta_effect *)
}
(* smart constructors *)
let vty_vars s = function
| VTvalue vtv -> vars_union s vtv.vtv_vars
| VTarrow vta -> vars_union s vta.vta_vars
let vty_value ?(ghost=false) ?mut ity =
let vars = ity.ity_vars in
let vars = match mut with
......@@ -670,27 +722,163 @@ let vtv_unmut vtv =
if vtv.vtv_mut = None then vtv else
vty_value ~ghost:vtv.vtv_ghost vtv.vtv_ity
type pvsymbol = {
pv_vs : vsymbol;
pv_vtv : vty_value;
}
module PVsym = WeakStructMake (struct
type t = pvsymbol
let tag pv = pv.pv_vs.vs_name.id_tag
end)
module Spv = PVsym.S
module Mpv = PVsym.M
module Hpv = PVsym.H
module Wpv = PVsym.W
let pv_equal : pvsymbol -> pvsymbol -> bool = (==)
let create_pvsymbol id vtv = {
pv_vs = create_vsymbol id (ty_of_ity vtv.vtv_ity);
pv_vtv = vtv;
}
let create_pvsymbol, restore_pv =
let vs_to_pv = Wvs.create 17 in
(fun id vtv ->
let pv = create_pvsymbol id vtv in
Wvs.set vs_to_pv pv.pv_vs pv;
pv),
(fun vs -> try Wvs.find vs_to_pv vs with Not_found ->
Loc.error ?loc:vs.vs_name.id_loc (Decl.UnboundVar vs))
(** program types *)
type vty =
| VTvalue of vty_value
| VTarrow of vty_arrow
and vty_arrow = {
vta_args : pvsymbol list;
vta_result : vty;
vta_spec : spec;
vta_ghost : bool;
vta_vars : varset;
(* this varset covers every type variable and region in vta_arg
and vta_result, but may skip some type variables and regions
in vta_effect *)
}
let vty_vars = function
| VTvalue vtv -> vtv.vtv_vars
| VTarrow vta -> vta.vta_vars
let vty_ghost = function
| VTvalue vtv -> vtv.vtv_ghost
| VTarrow vta -> vta.vta_ghost
let vty_arrow vtv ?(effect=eff_empty) ?(ghost=false) vty =
(* mutable arguments are rejected outright *)
if vtv.vtv_mut <> None then
Loc.errorm "Mutable arguments are not allowed in vty_arrow";
(* we accept a mutable vty_value as a result to simplify Mlw_expr,
but erase it in the signature: only projections return mutables *)
let vty = match vty with
| VTvalue v -> VTvalue (vtv_unmut v)
| VTarrow _ -> vty
in {
vta_arg = vtv;
let ity_of_vty = function
| VTvalue vtv -> vtv.vtv_ity
| VTarrow _ -> ity_unit
let ty_of_vty = function
| VTvalue vtv -> ty_of_ity vtv.vtv_ity
| VTarrow _ -> ty_unit
let spec_check spec vty = spec_check spec (ty_of_vty vty)
let vty_arrow_unsafe argl ~spec ~ghost vty =
let add_arg vars { pv_vtv = vtv } = vars_union vars vtv.vtv_vars in
{ vta_args = argl;
vta_result = vty;
vta_effect = effect;
vta_spec = spec;
vta_ghost = ghost || vty_ghost vty;
vta_vars = vty_vars vtv.vtv_vars vty;
vta_vars = List.fold_left add_arg (vty_vars vty) argl;
}
let vty_arrow argl ?spec ?(ghost=false) vty =
(* we accept a mutable vty_value as a result to simplify Mlw_expr,
but drop it in the signature: only projections return mutables *)
let vty = match vty with
| VTvalue v -> VTvalue (vtv_unmut v)
| VTarrow _ -> vty in
(* the arguments must be all distinct and at least one must be given *)
if argl = [] then invalid_arg "Mlw.vty_arrow";
let add_arg pvs pv =
(* mutable arguments are rejected outright *)
if pv.pv_vtv.vtv_mut <> None then invalid_arg "Mlw.vty_arrow";
Spv.add_new (Invalid_argument "Mlw.vty_arrow") pv pvs in
ignore (List.fold_left add_arg Spv.empty argl);
let spec = match spec with
| Some spec -> spec_check spec vty; spec
| None -> spec_empty (ty_of_vty vty) in
vty_arrow_unsafe argl ~spec ~ghost vty
(* this only compares the types of arguments and results, and ignores
the spec. In other words, only the type variables and regions
in .vta_vars are matched. The caller should supply a "freezing"
substitution that covers all external type variables and regions. *)
let rec vta_vars_match s a1 a2 =
let vtv_match s v1 v2 = ity_match s v1.vtv_ity v2.vtv_ity in
let rec match_args s l1 l2 = match l1, l2 with
| [],[] -> s, a1.vta_result, a2.vta_result
| [], _ -> s, a1.vta_result, VTarrow { a2 with vta_args = l2 }
| _, [] -> s, VTarrow { a1 with vta_args = l1 }, a2.vta_result
| {pv_vtv = v1}::l1, {pv_vtv = v2}::l2 ->
match_args (vtv_match s v1 v2) l1 l2
in
let s, vty1, vty2 = match_args s a1.vta_args a2.vta_args in
match vty1, vty2 with
| VTarrow a1, VTarrow a2 -> vta_vars_match s a1 a2
| VTvalue v1, VTvalue v2 -> vtv_match s v1 v2
| _ -> invalid_arg "Mlw_ty.vta_vars_match"
(* the substitution must cover not only vta.vta_tvs and vta.vta_regs
but also every type variable and every region in vta_spec *)
let vta_full_inst sbs vta =
let tvm = Mtv.map ty_of_ity sbs.ity_subst_tv in
let vtv_inst { vtv_ity = ity; vtv_ghost = ghost } =
vty_value ~ghost (ity_full_inst sbs ity) in
let pv_inst { pv_vs = vs; pv_vtv = vtv } =
create_pvsymbol (id_clone vs.vs_name) (vtv_inst vtv) in
let add_arg vsm pv =
let nv = pv_inst pv in
Mvs.add pv.pv_vs (t_var nv.pv_vs) vsm, nv in
let rec vta_inst vsm vta =
let vsm, args = Util.map_fold_left add_arg vsm vta.vta_args in
let spec = spec_full_inst sbs tvm vsm vta.vta_spec in
let vty = match vta.vta_result with
| VTarrow vta -> VTarrow (vta_inst vsm vta)
| VTvalue vtv -> VTvalue (vtv_inst vtv) in
vty_arrow_unsafe args ~ghost:vta.vta_ghost ~spec vty
in
vta_inst Mvs.empty vta