Commit ed26051a authored by Andrei Paskevich's avatar Andrei Paskevich

count the number of variable occurrences in t_vars

parent fe3e4378
......@@ -40,9 +40,8 @@ type logic_decl = lsymbol * ls_defn option
exception UnboundVar of vsymbol
let check_fvs f =
let fvs = t_freevars Svs.empty (t_prop f) in
Svs.iter (fun vs -> raise (UnboundVar vs)) fvs;
f
Mvs.iter (fun vs _ -> raise (UnboundVar vs)) f.t_vars;
t_prop f
let check_vl ty v = ty_equal_check ty v.vs_ty
let check_tl ty t = ty_equal_check ty (t_type t)
......@@ -568,14 +567,14 @@ let merge_known kn1 kn2 =
Mid.union check_known kn1 kn2
let known_add_decl kn0 decl =
let kn = Mid.map (fun _ -> decl) decl.d_news in
let kn = Mid.map (const decl) decl.d_news in
let check id decl0 _ =
if d_equal decl0 decl
then raise (KnownIdent id)
else raise (RedeclaredIdent id)
in
let kn = Mid.union check kn0 kn in
let unk = Mid.diff (fun _ _ _ -> None) decl.d_syms kn in
let unk = Mid.set_diff decl.d_syms kn in
if Sid.is_empty unk then kn
else raise (UnknownIdent (Sid.choose unk))
......
......@@ -128,7 +128,7 @@ module Compile (X : Action) = struct
if Mls.mem cs types then comp_cases cs al else comp_wilds ()
| _ ->
let base =
if Mls.submap (const3 true) css types then []
if Mls.set_submap css types then []
else [mk_branch (pat_wild ty) (comp_wilds ())]
in
let add cs ql acc =
......
......@@ -261,7 +261,7 @@ type term = {
t_ty : ty option;
t_label : label list;
t_loc : Loc.position option;
t_vars : Svs.t;
t_vars : int Mvs.t;
t_tag : int;
}
......@@ -286,7 +286,7 @@ and term_quant = vsymbol list * bind_info * trigger * term
and trigger = term list list
and bind_info = {
bv_vars : Svs.t; (* free variables *)
bv_vars : int Mvs.t; (* free variables *)
bv_subst : term Mvs.t (* deferred substitution *)
}
......@@ -339,6 +339,10 @@ let bnd_map_fold fn acc bv =
(* hash-consing for terms and formulas *)
let some_plus _ m n = Some (m + n)
let add_t_vars s t = Mvs.union some_plus s t.t_vars
let add_b_vars s (_,b,_) = Mvs.union some_plus s b.bv_vars
module Hsterm = Hashcons.Make (struct
type t = term
......@@ -408,21 +412,18 @@ module Hsterm = Hashcons.Make (struct
Hashcons.combine (t_hash_node t.t_node)
(Hashcons.combine_list Hashtbl.hash (oty_hash t.t_ty) t.t_label)
let add_t_vars s t = Svs.union s t.t_vars
let add_b_vars s (_,b,_) = Svs.union s b.bv_vars
let t_vars_node = function
| Tvar v -> Svs.singleton v
| Tconst _ -> Svs.empty
| Tapp (_,tl) -> List.fold_left add_t_vars Svs.empty tl
| Tvar v -> Mvs.singleton v 1
| Tconst _ -> Mvs.empty
| Tapp (_,tl) -> List.fold_left add_t_vars Mvs.empty tl
| Tif (f,t,e) -> add_t_vars (add_t_vars f.t_vars t) e
| Tlet (t,bt) -> add_b_vars t.t_vars bt
| Tcase (t,bl) -> List.fold_left add_b_vars t.t_vars bl
| Teps (_,b,_) -> b.bv_vars
| Tquant (_,(_,b,_,_)) -> b.bv_vars
| Tbinop (_,f1,f2) -> Svs.union f1.t_vars f2.t_vars
| Tbinop (_,f1,f2) -> add_t_vars f1.t_vars f2
| Tnot f -> f.t_vars
| Ttrue | Tfalse -> Svs.empty
| Ttrue | Tfalse -> Mvs.empty
let tag n t = { t with t_tag = n ; t_vars = t_vars_node t.t_node }
......@@ -443,7 +444,7 @@ let mk_term n ty = Hsterm.hashcons {
t_node = n;
t_label = [];
t_loc = None;
t_vars = Svs.empty;
t_vars = Mvs.empty;
t_ty = ty;
t_tag = -1
}
......@@ -570,13 +571,15 @@ let rec t_subst_unsafe m t =
and bv_subst_unsafe m b =
(* restrict m to the variables free in b *)
let m = Mvs.inter (fun _ t () -> Some t) m b.bv_vars in
let m = Mvs.set_inter m b.bv_vars in
(* if m is empty, return early *)
if Mvs.is_empty m then b else
(* remove from b.bv_vars the variables replaced by m *)
let s = Mvs.diff (fun _ () _ -> None) b.bv_vars m in
let s = Mvs.set_diff b.bv_vars m in
(* add to b.bv_vars the free variables added by m *)
let s = Mvs.fold (fun _ t -> Svs.union t.t_vars) m s in
let mult n s = if n = 1 then s else Mvs.map (fun i -> i * n) s in
let join _ n t s = Mvs.union some_plus (mult n t.t_vars) s in
let s = Mvs.fold2_inter join b.bv_vars m s in
(* apply m to the terms in b.bv_subst *)
let h = Mvs.map (t_subst_unsafe m) b.bv_subst in
(* join m to b.bv_subst *)
......@@ -591,14 +594,13 @@ let t_subst_unsafe m t =
let bnd_new s = { bv_vars = s ; bv_subst = Mvs.empty }
let t_close_bound v t = (v, bnd_new (Svs.remove v t.t_vars), t)
let t_close_bound v t = (v, bnd_new (Mvs.remove v t.t_vars), t)
let t_close_branch p t = (p, bnd_new (Svs.diff t.t_vars p.pat_vars), t)
let t_close_branch p t = (p, bnd_new (Mvs.set_diff t.t_vars p.pat_vars), t)
let t_close_quant vl tl f =
let del_v s v = Svs.remove v s in
let add_t s t = Svs.union s t.t_vars in
let s = tr_fold add_t f.t_vars tl in
let del_v s v = Mvs.remove v s in
let s = tr_fold add_t_vars f.t_vars tl in
let s = List.fold_left del_v s vl in
(vl, bnd_new s, tl, t_prop f)
......@@ -833,7 +835,7 @@ let t_gen_map fnT fnL mapV t = t_gen_map (Wty.memoize 17 fnT) fnL mapV t
(* map over type and logic symbols *)
let gen_mapV fnT = Mvs.mapi (fun v () -> t_var (gen_fresh_vsymbol fnT v))
let gen_mapV fnT = Mvs.mapi (fun v _ -> t_var (gen_fresh_vsymbol fnT v))
let t_s_map fnT fnL t = t_gen_map fnT fnL (gen_mapV fnT t.t_vars) t
......@@ -1040,16 +1042,10 @@ let t_v_map fn t =
let res = fn v in ty_equal_check v.vs_ty (t_type res); res in
t_subst_unsafe (Mvs.mapi fn t.t_vars) t
let t_v_fold fn acc t = Svs.fold (fun v a -> fn a v) t.t_vars acc
let t_v_all pr t = Svs.for_all pr t.t_vars
let t_v_any pr t = Svs.exists pr t.t_vars
(* looks for occurrence of a variable from set [s] in a term [t] *)
let t_occurs s t = not (Svs.is_empty (Svs.inter s t.t_vars))
let t_v_fold fn acc t = Mvs.fold (fun v _ a -> fn a v) t.t_vars acc
let t_occurs_single v t = Svs.mem v t.t_vars
let t_v_all pr t = Mvs.for_all (fun v _ -> pr v) t.t_vars
let t_v_any pr t = Mvs.exists (fun v _ -> pr v) t.t_vars
(* replaces variables with terms in term [t] using map [m] *)
......@@ -1061,7 +1057,7 @@ let t_subst_single v t1 t = t_subst (Mvs.singleton v t1) t
(* set of free variables *)
let t_freevars s t = Svs.union s t.t_vars
let t_freevars = add_t_vars
(* alpha-equality *)
......@@ -1323,32 +1319,36 @@ let t_if_simp f1 f2 f3 =
| _, _, Tfalse -> t_and_simp f1 f2
| _, _, _ -> if t_equal f2 f3 then f2 else f123
let t_let_simp e ((v,_,t) as bt) = match e.t_node with
| _ when not (Svs.mem v t.t_vars) -> snd (t_open_bound bt)
| Tvar _ -> let v,t = t_open_bound bt in t_subst_single v e t
| _ ->
begin match t.t_node with
| Tvar v' when vs_equal v v' -> e
| _ -> t_let e bt
end
let t_let_close_simp v e t = match e.t_node with
| _ when not (Svs.mem v t.t_vars) -> t
| Tvar _ -> t_subst_single v e t
| _ ->
begin match t.t_node with
| Tvar v' when vs_equal v v' -> e
| _ -> t_let_close v e t
end
let small t = match t.t_node with
| Tvar _ | Tconst _ -> true
| _ -> false
let t_let_simp e ((v,b,t) as bt) =
let n = Mvs.find_default v 0 t.t_vars in
if n = 0 then
t_subst_unsafe b.bv_subst t else
if n = 1 || small e then begin
ty_equal_check v.vs_ty (t_type e);
t_subst_unsafe (Mvs.add v e b.bv_subst) t
end else
t_let e bt
let t_let_close_simp v e t =
let n = Mvs.find_default v 0 t.t_vars in
if n = 0 then t else
if n = 1 || small e then
t_subst_single v e t
else
t_let_close v e t
let vl_filter f vl =
List.filter (fun v -> Svs.mem v f.t_vars) vl
let v_occurs f v = Mvs.mem v f.t_vars
let v_subset f e = Mvs.set_submap e.t_vars f.t_vars
let tr_filter f tl =
List.filter (List.for_all (fun e -> Svs.subset e.t_vars f.t_vars)) tl
let vl_filter f vl = List.filter (v_occurs f) vl
let tr_filter f tl = List.filter (List.for_all (v_subset f)) tl
let t_quant_simp q ((vl,_,_,f) as qf) =
if List.for_all (fun v -> Svs.mem v f.t_vars) vl then
if List.for_all (v_occurs f) vl then
t_quant q qf
else
let vl,tl,f = t_open_quant qf in
......@@ -1356,7 +1356,7 @@ let t_quant_simp q ((vl,_,_,f) as qf) =
t_quant_close q vl (tr_filter f tl) f
let t_quant_close_simp q vl tl f =
if List.for_all (fun v -> Svs.mem v f.t_vars) vl then
if List.for_all (v_occurs f) vl then
t_quant_close q vl tl f
else
let vl = vl_filter f vl in if vl = [] then f else
......
......@@ -129,7 +129,7 @@ type term = private {
t_ty : ty option;
t_label : label list;
t_loc : Loc.position option;
t_vars : Svs.t;
t_vars : int Mvs.t;
t_tag : int;
}
......@@ -366,10 +366,7 @@ val t_v_fold : ('a -> vsymbol -> 'a) -> 'a -> term -> 'a
val t_v_all : (vsymbol -> bool) -> term -> bool
val t_v_any : (vsymbol -> bool) -> term -> bool
(** Variable occurrence check and substitution *)
val t_occurs : Svs.t -> term -> bool
val t_occurs_single : vsymbol -> term -> bool
(** Variable substitution *)
val t_subst : term Mvs.t -> term -> term
val t_subst_single : vsymbol -> term -> term -> term
......@@ -378,7 +375,7 @@ val t_ty_subst : ty Mtv.t -> term Mvs.t -> term -> term
(** Find free variables and type variables *)
val t_freevars : Svs.t -> term -> Svs.t
val t_freevars : int Mvs.t -> term -> int Mvs.t
val t_ty_freevars : Stv.t -> term -> Stv.t
(** Map/fold over types and logical symbols *)
......
......@@ -677,8 +677,8 @@ let check_at_fmla loc f0 =
let v = ref None in
let rec check f = match f.t_node with
| Term.Tapp (ls, _) when ls_equal ls fs_at || ls_equal ls fs_old ->
let d = Svs.diff f.t_vars f0.t_vars in
Svs.is_empty d || (v := Some (Svs.choose d); false)
let d = Mvs.set_diff f.t_vars f0.t_vars in
Mvs.is_empty d || (v := Some (fst (Mvs.choose d)); false)
| _ ->
t_all check f
in
......
......@@ -86,13 +86,13 @@ let wp_forall v f =
(* if t_occurs_single v f then t_forall_close_simp [v] [] f else f *)
match f.t_node with
| Tbinop (Timplies, {t_node = Tapp (s,[{t_node = Tvar u};r])},h)
when ls_equal s ps_equ && vs_equal u v && not (t_occurs_single v r) ->
when ls_equal s ps_equ && vs_equal u v && not (Mvs.mem v r.t_vars) ->
t_let_close_simp v r h
| Tbinop (Timplies, {t_node = Tbinop (Tand, g,
{t_node = Tapp (s,[{t_node = Tvar u};r])})},h)
when ls_equal s ps_equ && vs_equal u v && not (t_occurs_single v r) ->
when ls_equal s ps_equ && vs_equal u v && not (Mvs.mem v r.t_vars) ->
t_let_close_simp v r (t_implies_simp g h)
| _ when t_occurs_single v f ->
| _ when Mvs.mem v f.t_vars ->
t_forall_close_simp [v] [] f
| _ ->
f
......
......@@ -55,7 +55,7 @@ let is_lambda t = destruct_lambda t <> LNone
let rec rewriteT t =
match t.t_node with
| Teps fb when is_lambda t ->
let fv = Svs.elements (t_freevars Svs.empty t) in
let fv = Mvs.keys t.t_vars in
let x, f = t_open_bound fb in
let f = rewriteF f in
if fv = [] then t_eps_close x f
......
......@@ -98,7 +98,7 @@ let rec rewriteT kn state t = match t.t_node with
and rewriteF kn state av sign f = match f.t_node with
| Tcase (t1,bl) ->
let t1 = rewriteT kn state t1 in
let av' = Svs.diff av (t_freevars Svs.empty t1) in
let av' = Mvs.set_diff av t1.t_vars in
let mk_br (w,m) br =
let (p,e) = t_open_branch br in
let e = rewriteF kn state av' sign e in
......@@ -148,7 +148,7 @@ and rewriteF kn state av sign f = match f.t_node with
TermTF.t_map_sign (const (rewriteT kn state))
(rewriteF kn state av) sign f
| Tlet (t1, _) ->
let av = Svs.diff av (t_freevars Svs.empty t1) in
let av = Mvs.set_diff av t1.t_vars in
TermTF.t_map_sign (const (rewriteT kn state))
(rewriteF kn state av) sign f
| _ ->
......
......@@ -70,7 +70,7 @@ module Transform = struct
let type_close_select tvs ts fn f =
let fold acc t = extract_tvar acc (app_type t) (t_type t) in
let tvm = List.fold_left fold Mtv.empty ts in
let tvs = Mtv.diff (const3 None) tvs tvm in
let tvs = Mtv.set_diff tvs tvm in
let get_vs tv = create_vsymbol (id_clone tv.tv_name) ty_type in
let tvm' = Mtv.mapi (fun v () -> get_vs v) tvs in
let vl = Mtv.values tvm' in
......@@ -143,7 +143,7 @@ module Transform = struct
(* Debug.print_list Pretty.print_ty Format.std_formatter type_vars; *)
let tv_to_ty = ty_match Mtv.empty (of_option lsymbol.ls_value) ty in
let new_ty = type_variable_only_in_value lsymbol in
let tv_to_ty = Mtv.inter (fun _ tv () -> Some tv) tv_to_ty new_ty in
let tv_to_ty = Mtv.set_inter tv_to_ty new_ty in
(* Debug.print_mtv Pretty.print_ty Format.err_formatter tv_to_ty; *)
let args = List.map (term_transform kept varM) args in
(* fresh args to be added at the beginning of the list of arguments *)
......
......@@ -111,7 +111,7 @@ let trivial tl =
let add vs t = match t.t_node with
| Tvar v when Mvs.mem v vs -> raise Util.FoldSkip
| Tvar v -> Svs.add v vs
| _ when Svs.is_empty (t_freevars Svs.empty t) -> vs
| _ when Mvs.is_empty t.t_vars -> vs
| _ -> raise Util.FoldSkip
in
try ignore (List.fold_left add Svs.empty tl); true
......
......@@ -31,7 +31,7 @@ let lift kind =
let rec term acc t =
match t.t_node with
| Teps fb ->
let fv = Svs.elements (t_freevars Svs.empty t) in
let fv = Mvs.keys t.t_vars in
let x, f = t_open_bound fb in
let acc, f = form acc f in
let tys = List.map (fun x -> x.vs_ty) fv in
......
......@@ -57,6 +57,9 @@ module type S =
val inter : (key -> 'a -> 'b -> 'c option) -> 'a t -> 'b t -> 'c t
val diff : (key -> 'a -> 'b -> 'a option) -> 'a t -> 'b t -> 'a t
val submap : (key -> 'a -> 'b -> bool) -> 'a t -> 'b t -> bool
val set_inter : 'a t -> 'b t -> 'a t
val set_diff : 'a t -> 'b t -> 'a t
val set_submap : 'a t -> 'b t -> bool
val find_default : key -> 'a -> 'a t -> 'a
val find_option : key -> 'a t -> 'a option
val map_filter: ('a -> 'b option) -> 'a t -> 'b t
......@@ -477,6 +480,11 @@ module Make(Ord: OrderedType) = struct
submap pr (Node (Empty, v1, d1, r1, 0)) r2 && submap pr l1 t2
let set_inter m1 m2 = inter (fun _ x _ -> Some x) m1 m2
let set_diff m1 m2 = diff (fun _ _ _ -> None) m1 m2
let set_submap m1 m2 = submap (fun _ _ _ -> true) m1 m2
let rec find_default x def = function
Empty -> def
| Node(l, v, d, r, _) ->
......
......@@ -208,6 +208,15 @@ module type S =
(** [submap pr m1 m2] verifies that all the keys in m1 are in m2
and that for each such binding pr is verified. *)
val set_inter : 'a t -> 'b t -> 'a t
(** [set_inter = inter (fun _ x _ -> Some x)] *)
val set_diff : 'a t -> 'b t -> 'a t
(** [set_diff = diff (fun _ _ _ -> None)] *)
val set_submap : 'a t -> 'b t -> bool
(** [set_submap = submap (fun _ _ _ -> true)] *)
val find_default : key -> 'a -> 'a t -> 'a
(** [find_default x d m] returns the current binding of [x] in [m],
or return [d] if no such binding exists. *)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment