Commit de5783d8 authored by François Bobot's avatar François Bobot
Browse files

Add more efficient function in Map and use them in core

parent 1f836cd2
......@@ -113,7 +113,7 @@ let rec match_term vm t acc p = match t.t_node, p.pat_node with
let build_call_graph cgr syms ls =
let call vm s tl =
let desc t = match t.t_node with
| Tvar v -> (try Mvs.find v vm with Not_found -> Unknown)
| Tvar v -> Mvs.find_default v Unknown vm
| _ -> Unknown
in
Hls.add cgr s (ls, Array.of_list (List.map desc tl))
......@@ -563,13 +563,9 @@ let known_id kn id =
if not (Mid.mem id kn) then raise (UnknownIdent id)
let merge_known kn1 kn2 =
let add_known id decl kn =
try
if not (d_equal (Mid.find id kn2) decl) then raise (RedeclaredIdent id);
kn
with Not_found -> Mid.add id decl kn
in
Mid.fold add_known kn1 kn2
let check_known id decl1 decl2 =
if d_equal decl1 decl2 then Some decl1 else raise (RedeclaredIdent id) in
Mid.union check_known kn1 kn2
let known_add_decl kn0 decl =
let add_known id kn =
......
......@@ -68,7 +68,7 @@ module Compile (X : Action) = struct
(* dispatch every case to a primitive constructor/wild case *)
let cases,wilds =
let add_case fs pl a cases =
let rl = try Mls.find fs cases with Not_found -> [] in
let rl = Mls.find_default fs [] cases in
Mls.add fs ((pl,a)::rl) cases
in
let add_wild pl a fs ql cases =
......
......@@ -118,7 +118,7 @@ let print_th_prelude task fmt pm =
| _ -> acc) [] task
in
List.iter (fun th ->
let prel = try Mid.find th.th_name pm with Not_found -> [] in
let prel = Mid.find_default th.th_name [] pm in
print_prelude fmt prel) th_used
exception KnownTypeSyntax of tysymbol
......
......@@ -54,12 +54,15 @@ let tds_hash tds = Hashweak.tag_hash tds.tds_tag
type clone_map = tdecl_set Mid.t
type meta_map = tdecl_set Mmeta.t
let cm_find cm th = try Mid.find th.th_name cm with Not_found -> empty_tds
let cm_find cm th = Mid.find_default th.th_name empty_tds cm
let mm_find mm t =
try Mmeta.find t mm with Not_found -> empty_tds
let mm_find mm t = Mmeta.find_default t empty_tds mm
let cm_add cm th td = Mid.add th.th_name (tds_add td (cm_find cm th)) cm
let cm_add cm th td = (* Mid.add th.th_name (tds_add td (cm_find cm th)) cm *)
Mid.change th.th_name
(function
| None -> Some (tds_singleton td)
| Some tds -> Some (tds_add td tds)) cm
let mm_add mm t td = if t.meta_excl
then Mmeta.add t (tds_singleton td) mm
......
......@@ -1572,9 +1572,14 @@ let rec t_match s t1 t2 =
if not (t_equal t1.t_ty t2.t_ty) then raise NoMatch else
match t1.t_node, t2.t_node with
| Tconst c1, Tconst c2 when c1 = c2 -> s
| Tvar v1, _ -> begin try
if t_equal (Mvs.find v1 s) t2 then s else raise NoMatch
with Not_found -> Mvs.add v1 t2 s end
| Tvar v1, _ ->
Mvs.change v1 (function
| None -> Some t2
| Some tv1 when t_equal tv1 t2 -> Some tv1
| _ -> raise NoMatch) s
(* begin try *)
(* if t_equal (Mvs.find v1 s) t2 then s else raise NoMatch *)
(* with Not_found -> Mvs.add v1 t2 s end *)
| Tapp (s1,l1), Tapp (s2,l2) when ls_equal s1 s2 ->
List.fold_left2 t_match s l1 l2
| Tif (f1,t1,e1), Tif (f2,t2,e2) ->
......
......@@ -47,10 +47,10 @@ exception ClashSymbol of string
let ns_add eq chk x v m =
if not chk then Mnm.add x v m
else try
if not (eq (Mnm.find x m) v) then raise (ClashSymbol x);
m
with Not_found -> Mnm.add x v m
else Mnm.change x (function
| None -> Some v
| Some vm when eq vm v -> Some vm
| _ -> raise (ClashSymbol x)) m
let ts_add = ns_add ts_equal
let ls_add = ns_add ls_equal
......@@ -63,8 +63,9 @@ let rec merge_ns chk ns1 ns2 =
ns_ns = Mnm.fold (fusion chk) ns1.ns_ns ns2.ns_ns; }
and fusion chk x ns m =
let os = try Mnm.find x m with Not_found -> empty_ns in
Mnm.add x (merge_ns chk os ns) m
Mnm.change x (function
| None -> Some (merge_ns chk empty_ns ns)
| Some os -> Some (merge_ns chk os ns)) m
let add_ts chk x ts ns = { ns with ns_ts = ts_add chk x ts ns.ns_ts }
let add_ls chk x ls ns = { ns with ns_ls = ls_add chk x ls ns.ns_ls }
......@@ -675,9 +676,9 @@ let add_meta uc s al = add_tdecl uc (create_meta s al)
let clone_meta tdt th tdc = match tdt.td_node, tdc.td_node with
| Meta (t,al), Clone (th',tm,lm,pm) when id_equal th.th_name th'.th_name ->
let find_ts ts = try Mts.find ts tm with Not_found -> ts in
let find_ls ls = try Mls.find ls lm with Not_found -> ls in
let find_pr pr = try Mpr.find pr pm with Not_found -> pr in
let find_ts ts = Mts.find_default ts ts tm in
let find_ls ls = Mls.find_default ls ls lm in
let find_pr pr = Mpr.find_default pr pr pm in
let cl_marg = function
| MAts ts -> MAts (find_ts ts)
| MAls ls -> MAls (find_ls ls)
......
......@@ -189,16 +189,17 @@ let ty_s_any pr ty =
(* type matching *)
let rec ty_inst s ty = match ty.ty_node with
| Tyvar n -> (try Mtv.find n s with Not_found -> ty)
| Tyvar n -> Mtv.find_default n ty s
| _ -> ty_map (ty_inst s) ty
let rec ty_match s ty1 ty2 =
if ty_equal ty1 ty2 then s
else match ty1.ty_node, ty2.ty_node with
| Tyvar n1, _ ->
(try if ty_equal (Mtv.find n1 s) ty2
then s else raise Exit
with Not_found -> Mtv.add n1 ty2 s)
Mtv.change n1 (function
| None -> Some ty2
| Some ty1 as r when ty_equal ty1 ty2 -> r
| _ -> raise Exit) s
| Tyapp (f1, l1), Tyapp (f2, l2) when ts_equal f1 f2 ->
List.fold_left2 ty_match s l1 l2
| _ ->
......
......@@ -50,6 +50,12 @@ module type S =
val find: key -> 'a t -> 'a
val map: ('a -> 'b) -> 'a t -> 'b t
val mapi: (key -> 'a -> 'b) -> 'a t -> 'b t
(** Added into why stdlib version *)
val change : key -> ('a option -> 'a option) -> 'a t -> 'a t
val union : (key -> 'a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t
val inter : (key -> 'a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t
val find_default : key -> 'a -> 'a t -> 'a
end
module Make(Ord: OrderedType) = struct
......@@ -315,6 +321,73 @@ module Make(Ord: OrderedType) = struct
let choose = min_binding
(** Added into why stdlib version *)
let rec change x f = function
| Empty ->
begin match f None with
| None -> Empty
| Some d -> Node(Empty, x, d, Empty, 1)
end
| Node(l, v, d, r, h) ->
let c = Ord.compare x v in
if c = 0 then
(* concat or bal *)
match f (Some d) with
| None -> concat l r
| Some d -> Node(l, x, d, r, h)
else if c < 0 then
bal (change x f l) v d r
else
bal l v d (change x f r)
let rec union f s1 s2 =
match (s1, s2) with
(Empty, t2) -> t2
| (t1, Empty) -> t1
| (Node(l1, v1, d1, r1, h1), Node(l2, v2, d2, r2, h2)) ->
if h1 >= h2 then
if h2 = 1 then
change v2 (function None -> Some d2 | Some d1 -> f v2 d1 d2) s1
else begin
let (l2, d2, r2) = split v1 s2 in
match d2 with
| None -> join (union f l1 l2) v1 d1 (union f r1 r2)
| Some d2 ->
concat_or_join (union f l1 l2) v1 (f v1 d1 d2)
(union f r1 r2)
end
else
if h1 = 1 then
change v1 (function None -> Some d1 | Some d2 -> f v1 d1 d2) s2
else begin
let (l1, d1, r1) = split v2 s1 in
match d1 with
| None -> join (union f l1 l2) v2 d2 (union f r1 r2)
| Some d1 ->
concat_or_join (union f l1 l2) v2 (f v2 d1 d2)
(union f r1 r2)
end
let rec inter f s1 s2 =
match (s1, s2) with
| (Empty, _) | (_, Empty) -> Empty
| (Node(l1, v1, d1, r1, _), t2) ->
match split v1 t2 with
(l2, None, r2) ->
concat (inter f l1 l2) (inter f r1 r2)
| (l2, Some d2, r2) ->
concat_or_join (inter f l1 l2) v1 (f v1 d1 d2) (inter f r1 r2)
let rec find_default x def = function
Empty -> def
| Node(l, v, d, r, _) ->
let c = Ord.compare x v in
if c = 0 then d
else find_default x def (if c < 0 then l else r)
end
end
......@@ -188,6 +188,35 @@ module type S =
key and the associated value for each binding of the map. *)
(** {3} Added into why stdlib version *)
val change : key -> ('a option -> 'a option) -> 'a t -> 'a t
(** [change x f m] returns a map containing the same bindings as
[m], except the binding of [x] in [m] is changed from [y] to
[f (Some y)] if [m] contains a binding of [x], otherwise the
binding of [x] becomes [f None].
[change x f m] corresponds to a more efficient way to do
[add x (try f (Some (find x m)) with Not_found -> f None) m]
*)
val union : (key -> 'a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t
(** [union f m1 m2] computes a map whose keys is a subset of keys of [m1]
and of [m2]. If a binding is present in [m1] (resp. [m2]) and not in
[m2] (resp. [m1]) the same binding is present in the result. Indeed the
function [f] is called only in ambiguous cases.
*)
val inter : (key -> 'a -> 'a -> 'a option) -> 'a t -> 'a t -> 'a t
(** [inter f m1 m2] computes a map whose keys is a subset of keys of [m1]
and of [m2].
*)
val find_default : key -> 'a -> 'a t -> 'a
(** [find_default x d m] returns the current binding of [x] in [m],
or return [d] if no such binding exists. *)
end
(** Output signature of the functor {!Map.Make}. *)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment