OST.ml 2.89 KB
Newer Older
POTTIER Francois's avatar
POTTIER Francois committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
(* This is a variant of OCaml's [Set] module, where each node in the binary
   search tree carries its size (i.e., the number of its elements). The tree
   is thus an order-statistic tree, and supports [select] and [rank] in
   logarithmic time and [cardinal] in constant time. *)

(* This implementation is minimalistic -- many set operations are missing. *)

module Make (Ord : Set.OrderedType) = struct

  (** Elements of the set, ordered by [Ord.compare]. *)
  type elt = Ord.t

  (** A binary search tree. [Node (l, v, r, h, s)] holds the left subtree
      [l], the element [v], the right subtree [r], the height [h] of this
      tree, and its size [s] (the number of elements it contains). *)
  type t = Empty | Node of t * elt * t * (* height: *) int * (* size: *) int

  (** [height t] is the height cached at the root of [t], or [0] if [t] is
      empty. *)
  let height t =
    match t with
    | Empty -> 0
    | Node (_, _, _, h, _) -> h

  (** [size t] is the number of elements cached at the root of [t], or [0]
      if [t] is empty. *)
  let size t =
    match t with
    | Empty -> 0
    | Node (_, _, _, _, s) -> s

  (** [create l v r] builds a node from [l], [v] and [r], deriving its height
      and size from those of [l] and [r]. No rebalancing is performed. *)
  let create l v r =
    let hl = height l and hr = height r in
    let h = if hl < hr then hr + 1 else hl + 1 in
    Node (l, v, r, h, size l + size r + 1)

  (** [bal l v r] is like [create l v r], except that when one subtree is
      more than 2 levels taller than the other, it performs a single or
      double rotation towards the shorter side. This mirrors [bal] in OCaml's
      [Set]; presumably callers guarantee that the heights of [l] and [r]
      differ by at most 3 (not checked here).
      @raise Invalid_argument ["Set.bal"] if a subtree that must exist is
      empty, which cannot happen while the cached heights are consistent. *)
  let bal l v r =
    let hl = height l and hr = height r in
    if hl > hr + 2 then
      (match l with
       | Empty -> invalid_arg "Set.bal"
       | Node (ll, lv, lr, _, _) ->
           if height ll < height lr then
             (* The inner grandchild [lr] is taller: double rotation. *)
             (match lr with
              | Empty -> invalid_arg "Set.bal"
              | Node (lrl, lrv, lrr, _, _) ->
                  create (create ll lv lrl) lrv (create lrr v r))
           else
             (* The outer grandchild [ll] is at least as tall: single
                rotation to the right. *)
             create ll lv (create lr v r))
    else if hr > hl + 2 then
      (match r with
       | Empty -> invalid_arg "Set.bal"
       | Node (rl, rv, rr, _, _) ->
           if height rr < height rl then
             (* The inner grandchild [rl] is taller: double rotation. *)
             (match rl with
              | Empty -> invalid_arg "Set.bal"
              | Node (rll, rlv, rlr, _, _) ->
                  create (create l v rll) rlv (create rlr rv rr))
           else
             (* The outer grandchild [rr] is at least as tall: single
                rotation to the left. *)
             create (create l v rl) rv rr)
    else
      (* Heights are close enough: just build the node. *)
      create l v r

  (** [add x s] is [s] with [x] inserted. If an element equal to [x]
      (according to [Ord.compare]) is already present, [s] itself is
      returned, physically unchanged. *)
  let rec add x t =
    match t with
    | Empty -> Node (Empty, x, Empty, 1, 1)
    | Node (l, v, r, _, _) ->
        let c = Ord.compare x v in
        if c < 0 then begin
          let l' = add x l in
          (* Physical equality: nothing was inserted below, so share [t]. *)
          if l' == l then t else bal l' v r
        end else if c > 0 then begin
          let r' = add x r in
          if r' == r then t else bal l v r'
        end else
          t

  (** The empty set. *)
  let empty = Empty

  (** [cardinal s] is the number of elements of [s], read in constant time
      from the size cached at the root. *)
  let cardinal s = size s

  (** [select i s] is the element of [s] at index [i] in increasing order
      (the [i]-th smallest, counting from [0]). It descends one level per
      step, so it runs in time proportional to the height of [s].
      @raise Assert_failure if [i] is negative or [i >= cardinal s]. *)
  let rec select i t =
    match t with
    | Empty ->
        (* Reached only when [i] is out of bounds. *)
        assert false
    | Node (l, v, r, _, s) ->
        assert (0 <= i && i < s);
        let sl = size l in
        if i = sl then v
        else if i < sl then select i l
        else select (i - sl - 1) r

  (** [pick s] is an element of [s] chosen at random, using the global
      [Random] state to draw an index for [select].
      @raise Not_found if [s] is empty. *)
  let pick xs =
    match size xs with
    | 0 -> raise Not_found
    | n -> select (Random.int n) xs

  (** [rank x s] is the index of [x] in [s] in increasing order, i.e. the
      number of elements of [s] smaller than [x]; it is the inverse of
      [select].
      @raise Not_found if [x] is not an element of [s]. *)
  let rank x t =
    (* [below] counts the elements already known to be smaller than [x]. *)
    let rec go below t =
      match t with
      | Empty -> raise Not_found
      | Node (l, v, r, _, _) ->
          let c = Ord.compare x v in
          if c < 0 then go below l
          else if c > 0 then go (below + size l + 1) r
          else below + size l
    in
    go 0 t

end