Commit c0696892 authored by Andrei Paskevich's avatar Andrei Paskevich
Browse files

whyml: back to the KISS principle

Merge specifications into program types, as JCF intended.
parent 14c2c552
...@@ -178,7 +178,7 @@ let letvar_news = function ...@@ -178,7 +178,7 @@ let letvar_news = function
| LetA ps -> check_vars ps.ps_vars; Sid.singleton ps.ps_name | LetA ps -> check_vars ps.ps_vars; Sid.singleton ps.ps_name
let create_let_decl ld = let create_let_decl ld =
let news = letvar_news ld.let_var in let news = letvar_news ld.let_sym in
(* (*
let syms = syms_varmap Sid.empty ld.let_expr.e_vars in let syms = syms_varmap Sid.empty ld.let_expr.e_vars in
let syms = syms_effect syms ld.let_expr.e_effect in let syms = syms_effect syms ld.let_expr.e_effect in
...@@ -207,7 +207,7 @@ let create_rec_decl rdl = ...@@ -207,7 +207,7 @@ let create_rec_decl rdl =
mk_decl (PDrec rdl) (*syms*) news mk_decl (PDrec rdl) (*syms*) news
let create_val_decl vd = let create_val_decl vd =
let news = letvar_news vd.val_name in let news = letvar_news vd.val_sym in
(* (*
let syms = syms_type_v Sid.empty vd.val_spec in let syms = syms_type_v Sid.empty vd.val_spec in
let syms = syms_varmap syms vd.val_vars in let syms = syms_varmap syms vd.val_vars in
......
...@@ -202,8 +202,9 @@ let unify d1 d2 = unify ~weak:false d1 d2 ...@@ -202,8 +202,9 @@ let unify d1 d2 = unify ~weak:false d1 d2
type dvty = dity list * dity (* A -> B -> C == ([A;B],C) *) type dvty = dity list * dity (* A -> B -> C == ([A;B],C) *)
let vty_of_dvty (argl,res) = let vty_of_dvty (argl,res) =
let add a v = VTarrow (vty_arrow (vty_value (ity_of_dity a)) v) in let vtv = VTvalue (vty_value (ity_of_dity res)) in
List.fold_right add argl (VTvalue (vty_value (ity_of_dity res))) let conv a = create_pvsymbol (id_fresh "x") (vty_value (ity_of_dity a)) in
if argl = [] then vtv else VTarrow (vty_arrow (List.map conv argl) vtv)
type tvars = dity list type tvars = dity list
...@@ -284,14 +285,14 @@ let specialize_xsymbol xs = ...@@ -284,14 +285,14 @@ let specialize_xsymbol xs =
let specialize_vtarrow vars vta = let specialize_vtarrow vars vta =
let htv = Htv.create 3 and hreg = Hreg.create 3 in let htv = Htv.create 3 and hreg = Hreg.create 3 in
let conv vtv = dity_of_vtv htv hreg vars vtv in let conv pv = dity_of_vtv htv hreg vars pv.pv_vtv in
let rec specialize a = let rec specialize a =
let arg = conv a.vta_arg in let argl = List.map conv a.vta_args in
let argl,res = match a.vta_result with let narg,res = match a.vta_result with
| VTvalue v -> [], conv v | VTvalue v -> [], dity_of_vtv htv hreg vars v
| VTarrow a -> specialize a | VTarrow a -> specialize a
in in
arg::argl, res argl @ narg, res
in in
specialize vta specialize vta
......
This diff is collapsed.
...@@ -27,26 +27,6 @@ open Term ...@@ -27,26 +27,6 @@ open Term
open Mlw_ty open Mlw_ty
open Mlw_ty.T open Mlw_ty.T
(** program variables *)
(* pvsymbols represent function arguments and pattern variables *)
type pvsymbol = private {
pv_vs : vsymbol;
pv_vtv : vty_value;
}
module Mpv : Map.S with type key = pvsymbol
module Spv : Mpv.Set
module Hpv : Hashtbl.S with type key = pvsymbol
module Wpv : Hashweak.S with type key = pvsymbol
val pv_equal : pvsymbol -> pvsymbol -> bool
val create_pvsymbol : preid -> vty_value -> pvsymbol
val restore_pv : vsymbol -> pvsymbol
(** program symbols *) (** program symbols *)
(* psymbols represent lambda-abstractions. They are polymorphic and (* psymbols represent lambda-abstractions. They are polymorphic and
...@@ -59,17 +39,12 @@ type psymbol = private { ...@@ -59,17 +39,12 @@ type psymbol = private {
ps_vars : varset; ps_vars : varset;
(* this varset covers the type variables and regions of the defining (* this varset covers the type variables and regions of the defining
lambda that cannot be instantiated. Every other type variable lambda that cannot be instantiated. Every other type variable
and region in ps_vty is generalized and can be instantiated. *) and region in ps_vta is generalized and can be instantiated. *)
ps_subst : ity_subst; ps_subst : ity_subst;
(* this substitution instantiates every type variable and region (* this substitution instantiates every type variable and region
in ps_vars to itself *) in ps_vars to itself *)
} }
module Mps : Map.S with type key = psymbol
module Sps : Mps.Set
module Hps : Hashtbl.S with type key = psymbol
module Wps : Hashweak.S with type key = psymbol
val ps_equal : psymbol -> psymbol -> bool val ps_equal : psymbol -> psymbol -> bool
val create_psymbol : preid -> vty_arrow -> psymbol val create_psymbol : preid -> vty_arrow -> psymbol
...@@ -102,39 +77,17 @@ exception HiddenPLS of lsymbol ...@@ -102,39 +77,17 @@ exception HiddenPLS of lsymbol
(** specification *) (** specification *)
type pre = term (* precondition: pre_fmla *) type let_sym =
type post = term (* postcondition: eps result . post_fmla *)
type xpost = post Mexn.t (* exceptional postconditions *)
val create_post : vsymbol -> term -> post
val open_post : post -> vsymbol * term
type type_c = {
c_pre : pre;
c_effect : effect;
c_result : type_v;
c_post : post;
c_xpost : xpost;
}
and type_v =
| SpecV of vty_value
| SpecA of pvsymbol list * type_c
type let_var =
| LetV of pvsymbol | LetV of pvsymbol
| LetA of psymbol | LetA of psymbol
type val_decl = private { type val_decl = private {
val_name : let_var; val_sym : let_sym;
val_spec : type_v; val_vty : vty;
val_vars : varset Mid.t; val_vars : varmap;
} }
val create_val : Ident.preid -> type_v -> val_decl val create_val : Ident.preid -> vty -> val_decl
exception DuplicateArg of pvsymbol
exception UnboundException of xsymbol
(** patterns *) (** patterns *)
...@@ -175,24 +128,23 @@ type expr = private { ...@@ -175,24 +128,23 @@ type expr = private {
e_node : expr_node; e_node : expr_node;
e_vty : vty; e_vty : vty;
e_effect : effect; e_effect : effect;
e_vars : varset Mid.t; e_vars : varmap;
e_label : Slab.t; e_label : Slab.t;
e_loc : Loc.position option; e_loc : Loc.position option;
e_tag : Hashweak.tag;
} }
and expr_node = private and expr_node = private
| Elogic of term | Elogic of term
| Evalue of pvsymbol | Evalue of pvsymbol
| Earrow of psymbol | Earrow of psymbol
| Eapp of expr * pvsymbol | Eapp of expr * pvsymbol * spec
| Elet of let_defn * expr | Elet of let_defn * expr
| Erec of rec_defn list * expr | Erec of rec_defn list * expr
| Eif of expr * expr * expr | Eif of expr * expr * expr
| Ecase of expr * (ppattern * expr) list | Ecase of expr * (ppattern * expr) list
| Eassign of expr * region * pvsymbol | Eassign of expr * region * pvsymbol
| Eghost of expr | Eghost of expr
| Eany of type_c | Eany of spec
| Eloop of invariant * variant list * expr | Eloop of invariant * variant list * expr
| Efor of pvsymbol * for_bounds * invariant * expr | Efor of pvsymbol * for_bounds * invariant * expr
| Eraise of xsymbol * expr | Eraise of xsymbol * expr
...@@ -202,14 +154,14 @@ and expr_node = private ...@@ -202,14 +154,14 @@ and expr_node = private
| Eabsurd | Eabsurd
and let_defn = private { and let_defn = private {
let_var : let_var; let_sym : let_sym;
let_expr : expr; let_expr : expr;
} }
and rec_defn = private { and rec_defn = private {
rec_ps : psymbol; rec_ps : psymbol;
rec_lambda : lambda; rec_lambda : lambda;
rec_vars : varset Mid.t; rec_vars : varmap;
} }
and lambda = { and lambda = {
...@@ -226,11 +178,6 @@ and variant = { ...@@ -226,11 +178,6 @@ and variant = {
v_rel : lsymbol option; (* tau tau : prop *) v_rel : lsymbol option; (* tau tau : prop *)
} }
module Mexpr : Map.S with type key = expr
module Sexpr : Mexpr.Set
module Hexpr : Hashtbl.S with type key = expr
module Wexpr : Hashweak.S with type key = expr
val e_label : ?loc:Loc.position -> Slab.t -> expr -> expr val e_label : ?loc:Loc.position -> Slab.t -> expr -> expr
val e_label_add : label -> expr -> expr val e_label_add : label -> expr -> expr
val e_label_copy : expr -> expr -> expr val e_label_copy : expr -> expr -> expr
...@@ -269,7 +216,7 @@ exception Immutable of expr ...@@ -269,7 +216,7 @@ exception Immutable of expr
val e_assign : expr -> expr -> expr val e_assign : expr -> expr -> expr
val e_ghost : expr -> expr val e_ghost : expr -> expr
val e_any : type_c -> expr val e_any : spec -> vty -> expr
val e_void : expr val e_void : expr
......
...@@ -311,7 +311,7 @@ let add_pdecl uc d = ...@@ -311,7 +311,7 @@ let add_pdecl uc d =
let defn cl = List.map constructor cl in let defn cl = List.map constructor cl in
let dl = List.map (fun (its,cl) -> its.its_pure, defn cl) dl in let dl = List.map (fun (its,cl) -> its.its_pure, defn cl) dl in
add_to_theory Theory.add_data_decl uc dl add_to_theory Theory.add_data_decl uc dl
| PDval { val_name = lv } | PDlet { let_var = lv } -> | PDval { val_sym = lv } | PDlet { let_sym = lv } ->
add_let uc lv add_let uc lv
| PDrec rdl -> | PDrec rdl ->
List.fold_left add_rec uc rdl List.fold_left add_rec uc rdl
......
...@@ -130,8 +130,9 @@ let print_vtv fmt vtv = ...@@ -130,8 +130,9 @@ let print_vtv fmt vtv =
fprintf fmt "%s%a" (if vtv.vtv_ghost then "?" else "") print_ity vtv.vtv_ity fprintf fmt "%s%a" (if vtv.vtv_ghost then "?" else "") print_ity vtv.vtv_ity
let rec print_vta fmt vta = let rec print_vta fmt vta =
fprintf fmt "%a ->@ %a%a" print_vtv vta.vta_arg let print_arg fmt pv = fprintf fmt "%a ->@ " print_vtv pv.pv_vtv in
print_effect vta.vta_effect print_vty vta.vta_result fprintf fmt "%a%a%a" (print_list nothing print_arg) vta.vta_args
print_effect vta.vta_spec.c_effect print_vty vta.vta_result
and print_vty fmt = function and print_vty fmt = function
| VTarrow vta -> print_vta fmt vta | VTarrow vta -> print_vta fmt vta
...@@ -168,18 +169,20 @@ let forget_lv = function ...@@ -168,18 +169,20 @@ let forget_lv = function
| LetA ps -> forget_ps ps | LetA ps -> forget_ps ps
let rec print_type_v fmt = function let rec print_type_v fmt = function
| SpecV vtv -> print_vtv fmt vtv | VTvalue vtv -> print_vtv fmt vtv
| SpecA (pvl,tyc) -> | VTarrow vta ->
let print_arg fmt pv = fprintf fmt "@[(%a)@] ->@ " print_pvty pv in let print_arg fmt pv = fprintf fmt "@[(%a)@] ->@ " print_pvty pv in
fprintf fmt "%a%a" (print_list nothing print_arg) pvl print_type_c tyc; fprintf fmt "%a%a"
List.iter forget_pv pvl (print_list nothing print_arg) vta.vta_args
(print_type_c vta.vta_spec) vta.vta_result;
List.iter forget_pv vta.vta_args
and print_type_c fmt tyc = and print_type_c spec fmt vty =
fprintf fmt "{ %a }@ %a%a@ { %a }" fprintf fmt "{ %a }@ %a%a@ { %a }"
print_term tyc.c_pre print_term spec.c_pre
print_effect tyc.c_effect print_effect spec.c_effect
print_type_v tyc.c_result print_type_v vty
print_post tyc.c_post print_post spec.c_post
(* TODO: print_xpost *) (* TODO: print_xpost *)
let print_invariant fmt f = let print_invariant fmt f =
...@@ -262,14 +265,14 @@ and print_enode pri fmt e = match e.e_node with ...@@ -262,14 +265,14 @@ and print_enode pri fmt e = match e.e_node with
print_pv fmt v print_pv fmt v
| Earrow a -> | Earrow a ->
print_ps fmt a print_ps fmt a
| Eapp (e,v) -> | Eapp (e,v,_) ->
fprintf fmt "(%a@ %a)" (print_lexpr pri) e print_pv v fprintf fmt "(%a@ %a)" (print_lexpr pri) e print_pv v
| Elet ({ let_var = LetV pv ; let_expr = e1 }, e2) | Elet ({ let_sym = LetV pv ; let_expr = e1 }, e2)
when pv.pv_vs.vs_name.id_string = "_" && when pv.pv_vs.vs_name.id_string = "_" &&
ity_equal pv.pv_vtv.vtv_ity ity_unit -> ity_equal pv.pv_vtv.vtv_ity ity_unit ->
fprintf fmt (protect_on (pri > 0) "%a;@\n%a") fprintf fmt (protect_on (pri > 0) "%a;@\n%a")
print_expr e1 print_expr e2; print_expr e1 print_expr e2;
| Elet ({ let_var = lv ; let_expr = e1 }, e2) -> | Elet ({ let_sym = lv ; let_expr = e1 }, e2) ->
fprintf fmt (protect_on (pri > 0) "@[<hov 2>let %a =@ %a@ in@]@\n%a") fprintf fmt (protect_on (pri > 0) "@[<hov 2>let %a =@ %a@ in@]@\n%a")
print_lv lv (print_lexpr 4) e1 print_expr e2; print_lv lv (print_lexpr 4) e1 print_expr e2;
forget_lv lv forget_lv lv
...@@ -309,8 +312,8 @@ and print_enode pri fmt e = match e.e_node with ...@@ -309,8 +312,8 @@ and print_enode pri fmt e = match e.e_node with
fprintf fmt "abstract %a@ { %a }" print_expr e print_post q fprintf fmt "abstract %a@ { %a }" print_expr e print_post q
| Eghost e -> | Eghost e ->
fprintf fmt "ghost@ %a" print_expr e fprintf fmt "ghost@ %a" print_expr e
| Eany tyc -> | Eany spec ->
fprintf fmt "any@ %a" print_type_c tyc fprintf fmt "any@ %a" (print_type_c spec) e.e_vty
and print_branch fmt ({ ppat_pattern = p }, e) = and print_branch fmt ({ ppat_pattern = p }, e) =
fprintf fmt "@[<hov 4>| %a ->@ %a@]" print_pat p print_expr e; fprintf fmt "@[<hov 4>| %a ->@ %a@]" print_pat p print_expr e;
...@@ -388,12 +391,12 @@ let print_data_decl fst fmt (ts,csl) = ...@@ -388,12 +391,12 @@ let print_data_decl fst fmt (ts,csl) =
(print_head fst) ts (print_list newline print_constr) csl; (print_head fst) ts (print_list newline print_constr) csl;
forget_tvs_regs () forget_tvs_regs ()
let print_val_decl fmt { val_name = lv ; val_spec = tyv } = let print_val_decl fmt { val_sym = lv ; val_vty = vty } =
fprintf fmt "@[<hov 2>val (%a) :@ %a@]" print_lv lv print_type_v tyv; fprintf fmt "@[<hov 2>val (%a) :@ %a@]" print_lv lv print_type_v vty;
(* FIXME: don't forget global regions *) (* FIXME: don't forget global regions *)
forget_tvs_regs () forget_tvs_regs ()
let print_let_decl fmt { let_var = lv ; let_expr = e } = let print_let_decl fmt { let_sym = lv ; let_expr = e } =
fprintf fmt "@[<hov 2>let %a =@ %a@]" print_lv lv print_expr e; fprintf fmt "@[<hov 2>let %a =@ %a@]" print_lv lv print_expr e;
(* FIXME: don't forget global regions *) (* FIXME: don't forget global regions *)
forget_tvs_regs () forget_tvs_regs ()
...@@ -438,6 +441,9 @@ let () = Exn_printer.register ...@@ -438,6 +441,9 @@ let () = Exn_printer.register
fprintf fmt "Region %a is used twice" print_reg r fprintf fmt "Region %a is used twice" print_reg r
| Mlw_ty.UnboundRegion r -> | Mlw_ty.UnboundRegion r ->
fprintf fmt "Unbound region %a" print_reg r fprintf fmt "Unbound region %a" print_reg r
| Mlw_ty.UnboundException xs ->
fprintf fmt "This function raises %a but does not \
specify a post-condition for it" print_xs xs
| Mlw_ty.RegionMismatch (r1,r2) -> | Mlw_ty.RegionMismatch (r1,r2) ->
fprintf fmt "Region mismatch between %a and %a" fprintf fmt "Region mismatch between %a and %a"
print_regty r1 print_regty r2 print_regty r1 print_regty r2
...@@ -467,10 +473,5 @@ let () = Exn_printer.register ...@@ -467,10 +473,5 @@ let () = Exn_printer.register
fprintf fmt "This expression is not a function and cannot be applied" fprintf fmt "This expression is not a function and cannot be applied"
| Mlw_expr.Immutable _e -> | Mlw_expr.Immutable _e ->
fprintf fmt "Mutable expression expected" fprintf fmt "Mutable expression expected"
| Mlw_expr.UnboundException xs ->
fprintf fmt "This function raises %a but does not \
specify a post-condition for it" print_xs xs
| Mlw_expr.DuplicateArg pv ->
fprintf fmt "Argument %a is used twice" print_pv pv
| _ -> raise exn | _ -> raise exn
end end
...@@ -53,9 +53,6 @@ val print_ppat : formatter -> ppattern -> unit (* program patterns *) ...@@ -53,9 +53,6 @@ val print_ppat : formatter -> ppattern -> unit (* program patterns *)
val print_expr : formatter -> expr -> unit (* expression *) val print_expr : formatter -> expr -> unit (* expression *)
val print_type_c : formatter -> type_c -> unit
val print_type_v : formatter -> type_v -> unit
val print_ty_decl : formatter -> itysymbol -> unit val print_ty_decl : formatter -> itysymbol -> unit
val print_data_decl : formatter -> data_decl -> unit val print_data_decl : formatter -> data_decl -> unit
val print_next_data_decl : formatter -> data_decl -> unit val print_next_data_decl : formatter -> data_decl -> unit
......
...@@ -34,6 +34,8 @@ module rec T : sig ...@@ -34,6 +34,8 @@ module rec T : sig
vars_reg : Reg.S.t; vars_reg : Reg.S.t;
} }
type varmap = varset Mid.t
type itysymbol = { type itysymbol = {
its_pure : tysymbol; its_pure : tysymbol;
its_args : tvsymbol list; its_args : tvsymbol list;
...@@ -66,6 +68,8 @@ end = struct ...@@ -66,6 +68,8 @@ end = struct
vars_reg : Reg.S.t; vars_reg : Reg.S.t;
} }
type varmap = varset Mid.t
type itysymbol = { type itysymbol = {
its_pure : tysymbol; its_pure : tysymbol;
its_args : tvsymbol list; its_args : tvsymbol list;
...@@ -126,6 +130,8 @@ let vars_union s1 s2 = { ...@@ -126,6 +130,8 @@ let vars_union s1 s2 = {
vars_reg = Sreg.union s1.vars_reg s2.vars_reg; vars_reg = Sreg.union s1.vars_reg s2.vars_reg;
} }
let vars_merge = Mid.fold (fun _ -> vars_union)
let create_varset tvs regs = { let create_varset tvs regs = {
vars_tv = Sreg.fold (fun r -> Stv.union r.reg_ity.ity_vars.vars_tv) regs tvs; vars_tv = Sreg.fold (fun r -> Stv.union r.reg_ity.ity_vars.vars_tv) regs tvs;
vars_reg = regs; vars_reg = regs;
...@@ -621,7 +627,74 @@ let eff_filter vars e = ...@@ -621,7 +627,74 @@ let eff_filter vars e =
eff_resets = Mreg.mapi_filter reset e.eff_resets; eff_resets = Mreg.mapi_filter reset e.eff_resets;
} }
(* program types *) (** specification *)
type pre = term (* precondition: pre_fmla *)
type post = term (* postcondition: eps result . post_fmla *)
type xpost = post Mexn.t (* exceptional postconditions *)
let create_post vs f = t_eps_close vs f
let open_post f = match f.t_node with
| Teps bf -> t_open_bound bf
| _ -> Loc.errorm "invalid post-condition"
let check_post ty f = match f.t_node with
| Teps _ -> Ty.ty_equal_check ty (t_type f)
| _ -> Loc.errorm "invalid post-condition"
type spec = {
c_pre : pre;
c_post : post;
c_xpost : xpost;
c_effect : effect;
}
let spec_empty ty = {
c_pre = t_true;
c_post = create_post (create_vsymbol (id_fresh "dummy") ty) t_true;
c_xpost = Mexn.empty;
c_effect = eff_empty;
}
let spec_full_inst sbs tvm vsm c =
let subst = t_ty_subst tvm vsm in {
c_pre = subst c.c_pre;
c_post = subst c.c_post;
c_xpost = Mexn.map subst c.c_xpost;
c_effect = eff_full_inst sbs c.c_effect;
}
let spec_subst sbs c =
let subst = t_subst sbs in {
c_pre = subst c.c_pre;
c_post = subst c.c_post;
c_xpost = Mexn.map subst c.c_xpost;
c_effect = c.c_effect;
}
let spec_filter varm vars c =
let add _ f s = Mvs.set_union f.t_vars s in
let vss = add () c.c_pre c.c_post.t_vars in
let vss = Mexn.fold add c.c_xpost vss in
let check { vs_name = id } _ = if not (Mid.mem id varm) then
Loc.errorm "Local variable %s escapes from its scope" id.id_string in
Mvs.iter check vss;
{ c with c_effect = eff_filter vars c.c_effect }
exception UnboundException of xsymbol
let spec_check c ty =
if c.c_pre.t_ty <> None then
Loc.error ?loc:c.c_pre.t_loc (Term.FmlaExpected c.c_pre);
check_post ty c.c_post;
Mexn.iter (fun xs q -> check_post (ty_of_ity xs.xs_ity) q) c.c_xpost;
let sexn = Sexn.union c.c_effect.eff_raises c.c_effect.eff_ghostx in
let sexn = Mexn.set_diff sexn c.c_xpost in
if not (Sexn.is_empty sexn) then
raise (UnboundException (Sexn.choose sexn))
(** program variables *)
type vty_value = { type vty_value = {
vtv_ity : ity; vtv_ity : ity;
...@@ -630,27 +703,6 @@ type vty_value = { ...@@ -630,27 +703,6 @@ type vty_value = {
vtv_vars : varset; vtv_vars : varset;
} }
type vty =
| VTvalue of vty_value
| VTarrow of vty_arrow
and vty_arrow = {
vta_arg : vty_value;
vta_result : vty;
vta_effect : effect;
vta_ghost : bool;
vta_vars : varset;
(* this varset covers every type variable and region in vta_arg
and vta_result, but may skip some type variables and regions
in vta_effect *)
}
(* smart constructors *)
let vty_vars s = function
| VTvalue vtv -> vars_union s vtv.vtv_vars
| VTarrow vta -> vars_union s vta.vta_vars
let vty_value ?(ghost=false) ?mut ity = let vty_value ?(ghost=false) ?mut ity =
let vars = ity.ity_vars in let vars = ity.ity_vars in
let vars = match mut with let vars = match mut with
...@@ -670,27 +722,163 @@ let vtv_unmut vtv = ...@@ -670,27 +722,163 @@ let vtv_unmut vtv =
if vtv.vtv_mut = None then vtv else