module Make (Ord: Map.OrderedType) =
struct
  type elt = Ord.t

  (* Sets are represented by balanced binary trees (the heights of the
     children differ by at most 2). [Node (l, v, r, h)] stores its left
     subtree [l], its element [v], its right subtree [r] and its cached
     height [h]; [Empty] has height 0. *)
  type t = Empty | Node of t * elt * t * int

  (* [height t] returns the cached height of [t] (0 for [Empty]). *)
  let height = function
      Empty -> 0
    | Node(_, _, _, h) -> h

  (* Creates a new node with left son l, value v and right son r, without
     any rebalancing.
     We must have all elements of l < v < all elements of r.
     l and r must be balanced and | height l - height r | <= 2.
     Inline expansion of height for better speed. *)
  let create l v r =
    let hl = match l with Empty -> 0 | Node(_,_,_,h) -> h in
    let hr = match r with Empty -> 0 | Node(_,_,_,h) -> h in
    Node(l, v, r, (if hl >= hr then hl + 1 else hr + 1))

  (* Same as create, but performs one step of rebalancing if necessary:
     a single rotation when the taller child's outer grandchild is at
     least as tall as its inner one, otherwise a double rotation.
     Assumes l and r balanced and | height l - height r | <= 3.
     Inline expansion of create for better speed in the most frequent case
     where no rebalancing is required.
     The [invalid_arg "Set.bal"] branches are guards: when the cached
     heights are consistent, the subtree taller by more than 2 cannot be
     [Empty], and neither can its taller inner grandchild. *)
  let bal l v r =
    let hl = match l with Empty -> 0 | Node(_,_,_,h) -> h in
    let hr = match r with Empty -> 0 | Node(_,_,_,h) -> h in
    if hl > hr + 2 then begin
      match l with
        Empty -> invalid_arg "Set.bal"
      | Node(ll, lv, lr, _) ->
          if height ll >= height lr then
            (* left-left case: single right rotation *)
            create ll lv (create lr v r)
          else begin
            (* left-right case: double rotation around lr *)
            match lr with
              Empty -> invalid_arg "Set.bal"
            | Node(lrl, lrv, lrr, _)->
                create (create ll lv lrl) lrv (create lrr v r)
          end
    end else if hr > hl + 2 then begin
      match r with
        Empty -> invalid_arg "Set.bal"
      | Node(rl, rv, rr, _) ->
          if height rr >= height rl then
            (* right-right case: single left rotation *)
            create (create l v rl) rv rr
          else begin
            (* right-left case: double rotation around rl *)
            match rl with
              Empty -> invalid_arg "Set.bal"
            | Node(rll, rlv, rlr, _) ->
                create (create l v rll) rlv (create rlr rv rr)
          end
    end else
      Node(l, v, r, (if hl >= hr then hl + 1 else hr + 1))

  (* [add x t] returns a set containing [x] and all elements of [t].
     It guarantees that it returns [t] (physically unchanged)
     if [x] is already a member of [t]; the physical-equality checks
     below also avoid rebuilding the path when a subtree is unchanged. *)
  let rec add x = function
      Empty -> Node(Empty, x, Empty, 1)
    | Node(l, v, r, _) as t ->
        let c = Ord.compare x v in
        if c = 0 then t else
        if c < 0 then
          let l' = add x l in
          if l == l' then t
          else bal l' v r
        else
          let r' = add x r in
          if r == r' then t
          else bal l v r'

  (* The empty set. *)
  let empty = Empty

  (* [find x t] returns the element of [t] that compares equal to [x]
     under [Ord.compare].
     @raise Not_found if no such element exists. *)
  let rec find x = function
      Empty -> raise Not_found
    | Node(l, v, r, _) ->
        let c = Ord.compare x v in
        if c = 0 then v
        else find x (if c < 0 then l else r)

  (* [find_opt x t] is [Some v] where [v] is the element of [t] that
     compares equal to [x] under [Ord.compare], or [None] if there is
     none. Total counterpart of [find]. *)
  let rec find_opt x = function
      Empty -> None
    | Node(l, v, r, _) ->
        let c = Ord.compare x v in
        if c = 0 then Some v
        else find_opt x (if c < 0 then l else r)

  (* [mem x t] is [true] iff [t] contains an element that compares equal
     to [x] under [Ord.compare]. *)
  let rec mem x = function
      Empty -> false
    | Node(l, v, r, _) ->
        let c = Ord.compare x v in
        c = 0 || mem x (if c < 0 then l else r)

  (* [iter f t] applies [f] to every element of [t] via an in-order
     traversal, i.e. in increasing order with respect to [Ord.compare]. *)
  let rec iter f = function
      Empty -> ()
    | Node(l, v, r, _) -> iter f l; f v; iter f r
end