(* sumrange.mlw *)

(** {1 Range Sum Queries}

We are interested in specifying and proving correct
data structures that support efficient computation of the sum of the
values over an arbitrary range of an array.
Concretely, given an array of integers `a`, and given a range
delimited by indices `i` (inclusive) and `j` (exclusive), we wish
to compute the value: `\sum_{k=i}^{j-1} a[k]`.

In the first part, we consider a simple loop
for computing the sum in linear time.

In the second part, we introduce a cumulative sum array
that allows answering arbitrary range queries in constant time.

In the third part, we explore a tree data structure that
supports modification of values from the underlying array `a`,
with logarithmic time operations.

*)


(** {2 Specification of Range Sum} *)

module ArraySum

  use int.Int
  use array.Array

  (** `sum a i j` denotes the sum `\sum_{i <= k < j} a[k]`.
      It is defined recursively, summing from the left, by the two
      following equations:

      if `j <= i` then `sum a i j = 0`

      if `i < j` then `sum a i j = a[i] + sum a (i+1) j`

      When called from programs, the range must lie within the
      bounds of `a`; termination follows from the decrease of `j - i`.
  *)
  let rec function sum (a:array int) (i j:int) : int
   requires { 0 <= i <= j <= a.length }
   variant { j - i }
   = if j <= i then 0 else a[i] + sum a (i+1) j

  (** lemma for summation from the right, for a non-empty range
      within the bounds of `a`:

      if `i < j` then `sum a i j = sum a i (j-1) + a[j-1]`

  *)
  lemma sum_right : forall a : array int, i j : int.
    0 <= i < j <= a.length  ->
    sum a i j = sum a i (j-1) + a[j-1]

end



(** {2 First algorithm, a linear one} *)

module Simple

  use int.Int
  use array.Array
  use ArraySum
  use ref.Ref

  (** `query a i j` returns the sum of elements in `a` between
      index `i` inclusive and index `j` exclusive.

      The range must satisfy `0 <= i <= j <= a.length`. The sum is
      accumulated by a single left-to-right loop over `i..j-1`. *)
  let query (a:array int) (i j:int) : int
    requires { 0 <= i <= j <= a.length }
    ensures { result = sum a i j }
  = let s = ref 0 in
    for k=i to j-1 do
      (* `s` holds the sum of the cells already visited, `a[i..k-1]` *)
      invariant { !s = sum a i k }
      s := !s + a[k]
    done;
    !s

end


(** {2 Additional lemmas on `sum`}
  needed in the remaining code *)

module ExtraLemmas

  use int.Int
  use array.Array
  use ArraySum

  (** Summation over adjacent intervals: the sum over `[i..k-1]` is
      the sum over `[i..j-1]` plus the sum over `[j..k-1]` *)
  lemma sum_concat : forall a:array int, i j k:int.
    0 <= i <= j <= k <= a.length ->
    sum a i k = sum a i j + sum a j k

  (** Frame lemma for `sum`, that is `sum a i j` depends only
      on the values of `a[i..j-1]`: two arrays that agree on that
      range have the same sum over it *)
  lemma sum_frame : forall a1 a2 : array int, i j : int.
    0 <= i <= j ->
    j <= a1.length ->
    j <= a2.length ->
    (forall k : int. i <= k < j -> a1[k] = a2[k]) ->
    sum a1 i j = sum a2 i j

  (** Update lemma for `sum`: how `sum a l h` changes when
      `a[i]` is set to `v` for some `i` in `[l..h-1]` *)
  lemma sum_update : forall a:array int, i v l h:int.
    0 <= l <= i < h <= a.length ->
    sum (a[i<-v]) l h = sum a l h + v - a[i]


end




(** {2 Algorithm 2: using a cumulative array}

   creation of the cumulative array is linear

   a query takes constant time

   an array update is linear

*)


module CumulativeArray

  use int.Int
  use array.Array
  use ArraySum
  use ExtraLemmas

  (** `is_cumulative_array_for c a` holds when `c` has one more cell
      than `a` and each `c[i]` is the sum of the first `i` elements
      of `a`, that is `sum a 0 i` *)
  predicate is_cumulative_array_for (c:array int) (a:array int) =
    c.length = a.length + 1 /\
    forall i. 0 <= i < c.length -> c[i] = sum a 0 i

  (** `create a` builds the cumulative array associated with `a`. *)
  let create (a:array int) : array int
    ensures { is_cumulative_array_for result a }
  = let l = a.length in
    let s = Array.make (l+1) 0 in
    for i=1 to l do
      (* cells `s[0..i-1]` already hold their final prefix sums;
         `s[0]` is the empty sum, set to 0 by `Array.make` *)
      invariant { forall k. 0 <= k < i -> s[k] = sum a 0 k }
      s[i] <- s[i-1] + a[i-1]
    done;
    s

  (** `query c i j a` returns the sum of elements in `a` between
      index `i` inclusive and index `j` exclusive, in constant time.

      The array `a` is a ghost parameter: only the cumulative array
      `c` is read, and the result is the difference of two of its
      cells. *)
  let query (c:array int) (i j:int) (ghost a:array int): int
    requires { is_cumulative_array_for c a }
    requires { 0 <= i <= j < c.length }
    ensures { result = sum a i j }
  = c[j] - c[i]


  (** `update c i v a` updates cell `a[i]` to value `v` and updates
      the cumulative array `c` accordingly.

      All other cells of `a` are left unchanged, and `c` is again
      the cumulative array of `a` on exit. *)
  let update (c:array int) (i:int) (v:int) (ghost a:array int) : unit
    requires { is_cumulative_array_for c a }
    requires { 0 <= i < a.length }
    writes  { c, a }
    ensures { is_cumulative_array_for c a }
    ensures { a[i] = v }
    ensures { forall k. 0 <= k < a.length /\ k <> i ->
              a[k] = (old a)[k] }
  = (* under the precondition, `c[i+1] - c[i]` is the current value of
       `a[i]`, so `incr` is the amount by which cell `i` changes; it is
       computed from `c` only, since `a` is ghost *)
    let incr = v - c[i+1] + c[i] in
    a[i] <- v;
    (* shift every prefix sum that includes cell `i` *)
    for j=i+1 to c.length-1 do
      (* cells from `j` on are still the old prefix sums ... *)
      invariant { forall k. j <= k < c.length -> c[k] = sum a 0 k - incr }
      (* ... while cells before `j` are already up to date *)
      invariant { forall k. 0 <= k < j -> c[k] = sum a 0 k }
      c[j] <- c[j] + incr
    done

end






(** {2 Algorithm 3: using a cumulative tree}

  creation is linear

  query is logarithmic

  update is logarithmic

*)


module CumulativeTree

  use int.Int
  use array.Array
  use ArraySum
  use ExtraLemmas
  use int.ComputerDivision

  (** Information stored at each node of the tree: the bounds `low`
      (inclusive) and `high` (exclusive) of the range it covers, and
      `isum`, the sum of the array over that range
      (see `is_indexes_for`). *)
  type indexes =
    { low : int;
      high : int;
      isum : int;
    }

  (** A binary tree whose leaves and internal nodes all carry an
      `indexes` record. *)
  type tree = Leaf indexes | Node indexes tree tree

  (** `indexes t` is the record stored at the root of `t`. *)
  let function indexes (t:tree) : indexes =
    match t with
    | Leaf ind -> ind
    | Node ind _ _ -> ind
    end

  (** `is_indexes_for ind a i j` holds when `ind` describes the
      non-empty range `[i..j-1]` of `a` together with its sum. *)
  predicate is_indexes_for (ind:indexes) (a:array int) (i j:int) =
    ind.low = i /\ ind.high = j /\
    0 <= i < j <= a.length /\
    ind.isum = sum a i j

  (** `is_tree_for t a i j` holds when `t` is a cumulative tree for
      the range `[i..j-1]` of `a`: a leaf covers exactly one cell, and
      a node is split at the middle index `m = div (i+j) 2`, its left
      child covering `[i..m-1]` and its right child `[m..j-1]`. *)
  predicate is_tree_for (t:tree) (a:array int) (i j:int) =
    match t with
    | Leaf ind ->
        is_indexes_for ind a i j /\ j = i+1
    | Node ind l r ->
        is_indexes_for ind a i j /\
        i = l.indexes.low /\ j = r.indexes.high /\
        let m = l.indexes.high in
        m = r.indexes.low /\
        i < m < j /\ m = div (i+j) 2 /\
        is_tree_for l a i m /\
        is_tree_for r a m j
    end

  (** {3 creation of cumulative tree} *)

  (** `tree_of_array a i j` builds the cumulative tree for the
      non-empty range `[i..j-1]` of `a`, recursively on the two halves
      of the range. *)
  let rec tree_of_array (a:array int) (i j:int) : tree
    requires { 0 <= i < j <= a.length }
    variant { j - i }
    ensures { is_tree_for result a i j }
    = if i+1=j then begin
       Leaf { low = i; high = j; isum = a[i] }
       end
      else
        begin
        let m = div (i+j) 2 in
        (* the range has at least two cells, so both halves are
           non-empty and strictly smaller *)
        assert { i < m < j };
        let l = tree_of_array a i m in
        let r = tree_of_array a m j in
        let s = l.indexes.isum + r.indexes.isum in
        assert { s = sum a i j };
        Node { low = i; high = j; isum = s} l r
        end


  (** `create a` builds the cumulative tree for the whole of `a`,
      which must be non-empty. *)
  let create (a:array int) : tree
    requires { a.length >= 1 }
    ensures { is_tree_for result a 0 a.length }
  = tree_of_array a 0 a.length


  (** {3 query using cumulative tree} *)


  (** `query_aux t a i j` returns `sum a i j` for a non-empty range
      `[i..j-1]` included in the range covered by `t`. The array `a`
      is ghost: only the sums stored in the tree are read. *)
  let rec query_aux (t:tree) (ghost a: array int)
      (i j:int) : int
    requires { is_tree_for t a t.indexes.low t.indexes.high }
    requires { 0 <= t.indexes.low <= i < j <= t.indexes.high <= a.length }
    variant { t }
    ensures { result = sum a i j }
  = match t with
    | Leaf ind ->
      (* a leaf covers a single cell, so the non-empty query range
         is the whole leaf *)
      ind.isum
    | Node ind l r ->
      let k1 = ind.low in
      let k3 = ind.high in
      (* the query is exactly the range of this node: use the stored sum *)
      if i=k1 && j=k3 then ind.isum else
      let m = l.indexes.high in
      (* query entirely in the left half, entirely in the right half,
         or straddling the middle, in which case it is split at `m` *)
      if j <= m then query_aux l a i j else
      if i >= m then query_aux r a i j else
      query_aux l a i m + query_aux r a m j
    end


  (** `query t a i j` returns the sum of elements in `a` between
      index `i` inclusive and index `j` exclusive, using the tree `t`
      built for the whole of `a`. The empty range is answered
      directly, since `query_aux` requires `i < j`. *)
  let query (t:tree) (ghost a: array int) (i j:int) : int
    requires { 0 <= i <= j <= a.length }
    requires { is_tree_for t a 0 a.length }
    ensures { result = sum a i j }
  = if i=j then 0 else query_aux t a i j


  (** frame lemma for predicate `is_tree_for`: updating a cell of `a`
      outside the range `[i..j-1]` preserves `is_tree_for t a i j` *)
  lemma is_tree_for_frame : forall t:tree, a:array int, k v i j:int.
    0 <= k < a.length ->
    k < i \/ k >= j ->
    is_tree_for t a i j ->
    is_tree_for t a[k<-v] i j

  (** {3 update cumulative tree} *)


  (** `update_aux t i a v` returns a tree for the array `a[i<-v]`,
      covering the same range as `t`, together with `delta`, the
      change `v - a[i]` of the value of cell `i`. The index `i` must
      lie in the range covered by `t`. Only one child is rebuilt at
      each node; the other one is reused unchanged. The array `a` is
      ghost and is not modified here. *)
  let rec update_aux
      (t:tree) (i:int) (ghost a :array int) (v:int) : (t': tree, delta: int)
    requires { is_tree_for t a t.indexes.low t.indexes.high }
    requires { t.indexes.low <= i < t.indexes.high }
    variant { t }
    ensures {
        delta = v - a[i] /\
        t'.indexes.low = t.indexes.low /\
        t'.indexes.high = t.indexes.high /\
        is_tree_for t' a[i<-v] t'.indexes.low t'.indexes.high }
  = match t with
    | Leaf ind ->
        (* a leaf covers the single cell `i`, whose old value is `ind.isum` *)
        assert { i = ind.low };
        (Leaf { ind with isum = v }, v - ind.isum)
    | Node ind l r ->
      let m = l.indexes.high in
      if i < m then
        let l',delta = update_aux l i a v in
        assert { is_tree_for l' a[i<-v] t.indexes.low m };
        (* the right child does not cover cell `i` and remains valid *)
        assert { is_tree_for r a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta } l' r, delta)
      else
        let r',delta = update_aux r i a v in
        (* the left child does not cover cell `i` and remains valid *)
        assert { is_tree_for l a[i<-v] t.indexes.low m };
        assert { is_tree_for r' a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta} l r',delta)
    end

  (** `update t a i v` sets the cell `a[i]` of the ghost array `a` to
      `v` and returns the cumulative tree for the updated array. All
      other cells of `a` are left unchanged. *)
  let update (t:tree) (ghost a:array int) (i v:int) : tree
     requires { 0 <= i < a.length }
     requires { is_tree_for t a 0 a.length }
     writes { a }
     ensures { a[i] = v }
     ensures { forall k. 0 <= k < a.length /\ k <> i -> a[k] = (old a)[k] }
     ensures { is_tree_for result a 0 a.length }
  = let t,_ = update_aux t i a v in
    (* the new tree is stated for `a[i <- v]` before `a` is modified *)
    assert { is_tree_for t a[i <- v] 0 a.length };
    a[i] <- v;
    t


  (** {2 complexity analysis}

    We would like to show that `update` and `query` are logarithmic.
    For `query` this is non-trivial because there are two recursive
    calls in some cases.

    We express the complexity by passing a ghost counter `c`, a
    reference that is incremented once at each call. The
    postconditions of the instrumented functions bound the increase
    of `c` in terms of the depth of the tree: by `depth t` for
    `update_aux_complexity`, and by at most `4 * depth t` for
    `query_aux_complexity`.

  *)

  (** preliminaries: definition of the depth of a tree, and showing
      that it is indeed logarithmic in function of the number of its
      elements *)

  use int.MinMax

  (** `depth t` is the number of nodes on a longest path from the
      root of `t` to a leaf; a leaf has depth 1. *)
  function depth (t:tree) : int =
    match t with
    | Leaf _ -> 1
    | Node _ l r -> 1 + max (depth l) (depth r)
    end

  (** the depth of a tree is at least 1 *)
  lemma depth_min : forall t. depth t >= 1

  use bv.Pow2int

  (** `depth_is_log t a k`: a cumulative tree covering at most
      `pow2 k` cells has depth at most `k+1`. By induction on `t`,
      each child of a node being called with `k-1`. *)
  let rec lemma depth_is_log (t:tree) (a :array int) (k:int)
     requires { k >= 0 }
     requires { is_tree_for t a t.indexes.low t.indexes.high }
     requires { t.indexes.high - t.indexes.low <= pow2 k }
     variant { t }
     ensures { depth t <= k+1 }
  = match t with
    | Leaf _ -> ()
    | Node _ l r ->
       depth_is_log l a (k-1);
       depth_is_log r a (k-1)
    end


  (** `update_aux` function instrumented with a credit *)

  use ref.Ref

  (** Same code and functional contract as `update_aux`, with an extra
      ghost counter `c` incremented at each call; the first
      postcondition bounds its increase by `depth t`. *)
  let rec update_aux_complexity
        (t:tree) (i:int) (ghost a :array int)
        (v:int) (ghost c:ref int) : (t': tree, delta: int)
     requires { is_tree_for t a t.indexes.low t.indexes.high }
     requires { t.indexes.low <= i < t.indexes.high }
     variant { t }
     ensures { !c - old !c <= depth t }
     ensures {
        delta = v - a[i] /\
        t'.indexes.low = t.indexes.low /\
        t'.indexes.high = t.indexes.high /\
        is_tree_for t' a[i<-v] t'.indexes.low t'.indexes.high }
  = c := !c + 1;
    match t with
    | Leaf ind ->
      assert { i = ind.low };
      (Leaf { ind with isum = v }, v - ind.isum)
    | Node ind l r ->
      let m = l.indexes.high in
      if i < m then
        let l',delta = update_aux_complexity l i a v c in
        assert { is_tree_for l' a[i<-v] t.indexes.low m };
        assert { is_tree_for r a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta } l' r, delta)
      else
        let r',delta = update_aux_complexity r i a v c in
        assert { is_tree_for l a[i<-v] t.indexes.low m };
        assert { is_tree_for r' a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta} l r',delta)
    end

  (** `query_aux` function instrumented with a credit *)

  (** Same code and functional contract as `query_aux`, with an extra
      ghost counter `c` incremented at each call. The first
      postcondition bounds its increase by 1 when the query is exactly
      the range of `t`, by `2 * depth t` when the query shares one of
      its bounds with the range of `t`, and by `4 * depth t`
      otherwise. *)
  let rec query_aux_complexity (t:tree) (ghost a: array int)
      (i j:int) (ghost c:ref int) : int
    requires { is_tree_for t a t.indexes.low t.indexes.high }
    requires { 0 <= t.indexes.low <= i < j <= t.indexes.high <= a.length }
    variant { t }
    ensures { !c - old !c <=
         if i = t.indexes.low /\ j = t.indexes.high then 1 else
         if i = t.indexes.low \/ j = t.indexes.high then 2 * depth t else
          4 * depth t }
    ensures { result = sum a i j }
  = c := !c + 1;
    match t with
    | Leaf ind ->
      ind.isum
    | Node ind l r ->
      let k1 = ind.low in
      let k3 = ind.high in
      if i=k1 && j=k3 then ind.isum else
      let m = l.indexes.high in
      if j <= m then query_aux_complexity l a i j c else
      if i >= m then query_aux_complexity r a i j c else
      query_aux_complexity l a i m c + query_aux_complexity r a m j c
    end

end