Commit f7544e06 authored by Andrei Paskevich's avatar Andrei Paskevich

avoid unnecessary h-consing in t_map and t_subst

introduce Map.disjoint, Map.set_disjoint and Set.disjoint
parent 211a72f2
......@@ -496,15 +496,30 @@ let bound_map fn (u,b,e) = (u, bnd_map fn b, fn e)
let t_map_unsafe fn t = t_label_copy t (match t.t_node with
| Tvar _ | Tconst _ -> t
| Tapp (f,tl) -> t_app f (List.map fn tl) t.t_ty
| Tif (f,t1,t2) -> t_if (fn f) (fn t1) (fn t2)
| Tlet (e,b) -> t_let (fn e) (bound_map fn b) t.t_ty
| Tcase (e,bl) -> t_case (fn e) (List.map (bound_map fn) bl) t.t_ty
| Teps b -> t_eps (bound_map fn b) t.t_ty
| Tapp (f,tl) ->
let sl = List.map fn tl in
if List.for_all2 t_equal sl tl then t else
t_app f sl t.t_ty
| Tif (f,t1,t2) ->
let g = fn f and s1 = fn t1 and s2 = fn t2 in
if t_equal g f && t_equal s1 t1 && t_equal s2 t2 then t else
t_if g s1 s2
| Tlet (e,b) ->
t_let (fn e) (bound_map fn b) t.t_ty
| Tcase (e,bl) ->
t_case (fn e) (List.map (bound_map fn) bl) t.t_ty
| Teps b ->
t_eps (bound_map fn b) t.t_ty
| Tquant (q,(vl,b,tl,f1)) ->
t_quant q (vl, bnd_map fn b, tr_map fn tl, fn f1)
| Tbinop (op,f1,f2) -> t_binary op (fn f1) (fn f2)
| Tnot f1 -> t_not (fn f1)
| Tbinop (op,f1,f2) ->
let g1 = fn f1 and g2 = fn f2 in
if t_equal g1 f1 && t_equal g2 f2 then t else
t_binary op g1 g2
| Tnot f1 ->
let g1 = fn f1 in
if t_equal g1 f1 then t else
t_not g1
| Ttrue | Tfalse -> t)
(* unsafe fold *)
......@@ -534,13 +549,15 @@ let t_map_fold_unsafe fn acc t = match t.t_node with
| Tvar _ | Tconst _ ->
acc, t
| Tapp (f,tl) ->
let acc,tl = map_fold_left fn acc tl in
acc, t_label_copy t (t_app f tl t.t_ty)
let acc,sl = map_fold_left fn acc tl in
if List.for_all2 t_equal sl tl then acc,t else
acc, t_label_copy t (t_app f sl t.t_ty)
| Tif (f,t1,t2) ->
let acc, f = fn acc f in
let acc, t1 = fn acc t1 in
let acc, t2 = fn acc t2 in
acc, t_label_copy t (t_if f t1 t2)
let acc, g = fn acc f in
let acc, s1 = fn acc t1 in
let acc, s2 = fn acc t2 in
if t_equal g f && t_equal s1 t1 && t_equal s2 t2 then acc,t else
acc, t_label_copy t (t_if g s1 s2)
| Tlet (e,b) ->
let acc, e = fn acc e in
let acc, b = bound_map_fold fn acc b in
......@@ -558,12 +575,14 @@ let t_map_fold_unsafe fn acc t = match t.t_node with
let acc, f1 = fn acc f1 in
acc, t_label_copy t (t_quant q (vl,b,tl,f1))
| Tbinop (op,f1,f2) ->
let acc, f1 = fn acc f1 in
let acc, f2 = fn acc f2 in
acc, t_label_copy t (t_binary op f1 f2)
let acc, g1 = fn acc f1 in
let acc, g2 = fn acc f2 in
if t_equal g1 f1 && t_equal g2 f2 then acc,t else
acc, t_label_copy t (t_binary op g1 g2)
| Tnot f1 ->
let acc, f1 = fn acc f1 in
acc, t_label_copy t (t_not f1)
let acc, g1 = fn acc f1 in
if t_equal g1 f1 then acc,t else
acc, t_label_copy t (t_not g1)
| Ttrue | Tfalse ->
acc, t
......@@ -572,17 +591,24 @@ let t_map_fold_unsafe fn acc t = match t.t_node with
let rec t_subst_unsafe m t =
let t_subst t = t_subst_unsafe m t in
let b_subst (u,b,e) = (u, bv_subst_unsafe m b, e) in
let nosubst (_,b,_) = Mvs.set_disjoint m b.bv_vars in
match t.t_node with
| Tvar u ->
Mvs.find_default u t m
| Tlet (e, bt) ->
t_label_copy t (t_let (t_subst e) (b_subst bt) t.t_ty)
let d = t_subst e in
if t_equal d e && nosubst bt then t else
t_label_copy t (t_let d (b_subst bt) t.t_ty)
| Tcase (e, bl) ->
let d = t_subst e in
if t_equal d e && List.for_all nosubst bl then t else
let bl = List.map b_subst bl in
t_label_copy t (t_case (t_subst e) bl t.t_ty)
t_label_copy t (t_case d bl t.t_ty)
| Teps bf ->
if nosubst bf then t else
t_label_copy t (t_eps (b_subst bf) t.t_ty)
| Tquant (q, (vl,b,tl,f1)) ->
if Mvs.set_disjoint m b.bv_vars then t else
let b = bv_subst_unsafe m b in
t_label_copy t (t_quant q (vl,b,tl,f1))
| _ ->
......
......@@ -57,9 +57,11 @@ module type S =
val inter : (key -> 'a -> 'b -> 'c option) -> 'a t -> 'b t -> 'c t
val diff : (key -> 'a -> 'b -> 'a option) -> 'a t -> 'b t -> 'a t
val submap : (key -> 'a -> 'b -> bool) -> 'a t -> 'b t -> bool
val disjoint : (key -> 'a -> 'b -> bool) -> 'a t -> 'b t -> bool
val set_inter : 'a t -> 'b t -> 'a t
val set_diff : 'a t -> 'b t -> 'a t
val set_submap : 'a t -> 'b t -> bool
val set_disjoint : 'a t -> 'b t -> bool
val find_default : key -> 'a -> 'a t -> 'a
val find_option : key -> 'a t -> 'a option
val map_filter: ('a -> 'b option) -> 'a t -> 'b t
......@@ -91,6 +93,7 @@ module type S =
val compare: t -> t -> int
val equal: t -> t -> bool
val subset: t -> t -> bool
val disjoint: t -> t -> bool
val iter: (elt -> unit) -> t -> unit
val fold: (elt -> 'a -> 'a) -> t -> 'a -> 'a
val for_all: (elt -> bool) -> t -> bool
......@@ -468,7 +471,7 @@ module Make(Ord: OrderedType) = struct
let rec submap pr s1 s2 =
match (s1, s2) with
| Empty, _ -> true
| Empty, _ -> true
| _, Empty -> false
| Node (l1, v1, d1, r1, _), (Node (l2, v2, d2, r2, _) as t2) ->
let c = Ord.compare v1 v2 in
......@@ -480,9 +483,24 @@ module Make(Ord: OrderedType) = struct
submap pr (Node (Empty, v1, d1, r1, 0)) r2 && submap pr l1 t2
let rec disjoint pr s1 s2 =
match (s1, s2) with
| Empty, _ -> true
| _, Empty -> true
| Node (l1, v1, d1, r1, _), (Node (l2, v2, d2, r2, _) as t2) ->
let c = Ord.compare v1 v2 in
if c = 0 then
pr v1 d1 d2 && disjoint pr l1 l2 && disjoint pr r1 r2
else if c < 0 then
disjoint pr (Node (l1, v1, d1, Empty, 0)) l2 && disjoint pr r1 t2
else
disjoint pr (Node (Empty, v1, d1, r1, 0)) r2 && disjoint pr l1 t2
let set_inter m1 m2 = inter (fun _ x _ -> Some x) m1 m2
let set_diff m1 m2 = diff (fun _ _ _ -> None) m1 m2
let set_submap m1 m2 = submap (fun _ _ _ -> true) m1 m2
let set_disjoint m1 m2 = disjoint (fun _ _ _ -> false) m1 m2
let rec find_default x def = function
......@@ -597,6 +615,7 @@ module Make(Ord: OrderedType) = struct
val compare: t -> t -> int
val equal: t -> t -> bool
val subset: t -> t -> bool
val disjoint: t -> t -> bool
val iter: (elt -> unit) -> t -> unit
val fold: (elt -> 'a -> 'a) -> t -> 'a -> 'a
val for_all: (elt -> bool) -> t -> bool
......@@ -639,6 +658,7 @@ module Make(Ord: OrderedType) = struct
let compare = compare (fun _ _ -> 0)
let equal = equal (fun _ _ -> true)
let subset = submap (fun _ _ _ -> true)
let disjoint = disjoint (fun _ _ _ -> false)
let iter f = iter (const f)
let fold f = fold (const f)
let for_all f = for_all (const f)
......
......@@ -208,6 +208,10 @@ module type S =
(** [submap pr m1 m2] verifies that all the keys in m1 are in m2
and that for each such binding pr is verified. *)
val disjoint : (key -> 'a -> 'b -> bool) -> 'a t -> 'b t -> bool
(** [disjoint pr m1 m2] verifies that for every common key in m1
and m2, pr is verified. *)
val set_inter : 'a t -> 'b t -> 'a t
(** [set_inter = inter (fun _ x _ -> Some x)] *)
......@@ -217,6 +221,9 @@ module type S =
val set_submap : 'a t -> 'b t -> bool
(** [set_submap = submap (fun _ _ _ -> true)] *)
val set_disjoint : 'a t -> 'b t -> bool
(** [set_disjoint = disjoint (fun _ _ _ -> false)] *)
val find_default : key -> 'a -> 'a t -> 'a
(** [find_default x d m] returns the current binding of [x] in [m],
or return [d] if no such binding exists. *)
......@@ -310,6 +317,10 @@ module type S =
val subset: t -> t -> bool
(** [subset s1 s2] tests whether the set [s1] is a subset of [s2]. *)
val disjoint: t -> t -> bool
(** [disjoint s1 s2] tests whether the sets [s1] and [s2]
are disjoint. *)
val iter: (elt -> unit) -> t -> unit
(** [iter f s] applies [f] to all elements of [s].
The elements are passed to [f] in increasing order with respect
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment