(** {1 Range Sum Queries}

We are interested in specifying and proving correct
data structures that support efficient computation of the sum of the
values over an arbitrary range of an array.
Concretely, given an array of integers [a], and given a range
delimited by indices [i] (inclusive) and [j] (exclusive), we wish
to compute the value: [\sum_{k=i}^{j-1} a[k]].

In the first part, we consider a simple loop
for computing the sum in linear time.

In the second part, we introduce a cumulative sum array
that allows answering arbitrary range queries in constant time.

In the third part, we explore a tree data structure that
supports modification of values from the underlying array [a],
with logarithmic time operations.

*)


(** {2 Specification of Range Sum} *)

module ArraySum

  use export int.Int
  use export array.Array

  (** [sum a i j] denotes the sum [\sum_{i <= k < j} a[k]].
      It is axiomatized by the two following axioms expressing
      the recursive definition

      if [j <= i] then [sum a i j = 0]

      if [i < j] and [0 <= i < a.length] then
      [sum a i j = a[i] + sum a (i+1) j]

  *)
  function sum (array int) int int : int

  (** empty range: the sum is zero as soon as [j <= i] *)
  axiom sum_def_empty :
    forall a : array int, i j : int. j <= i -> sum a i j = 0

  (** non-empty range: unfold the leftmost element; the axiom only
      applies when [i] is a valid index of [a] *)
  axiom sum_def_non_empty :
    forall a: array int, i j : int. i < j /\ 0 <= i < a.length ->
      sum a i j = a[i] + sum a (i+1) j

  (** lemma for summation from the right:

      if [0 <= i < j <= a.length] then
      [sum a i j = sum a i (j-1) + a[j-1]]

 *)
  lemma sum_right : forall a : array int, i j : int.
    0 <= i < j <= a.length  ->
    sum a i j = sum a i (j-1) + a[j-1]

end




(** {2 First algorithm, a linear one} *)

module Simple

  use import ArraySum
  use import ref.Ref

  (** [query a i j] computes [sum a i j] by visiting the cells
      [a[i]], ..., [a[j-1]] from left to right and accumulating
      their values in a reference *)
  let query (a:array int) (i j:int) : int
    requires { 0 <= i <= j <= a.length }
    ensures  { result = sum a i j }
  =
    let acc = ref 0 in
    (* [acc] always holds the sum of the cells visited so far *)
    for pos = i to j - 1 do
      invariant { !acc = sum a i pos }
      acc := !acc + a[pos]
    done;
    !acc

end


(** {2 Additional lemmas on [sum]}
  needed in the remaining code *)

module ExtraLemmas

  use array.Array
  use import ArraySum

  (** Splitting lemma: the sum over [[lo..hi-1]] is the sum over
      [[lo..mid-1]] plus the sum over [[mid..hi-1]] *)
  lemma sum_concat :
    forall arr : array int, lo mid hi : int.
      0 <= lo <= mid <= hi <= arr.length ->
      sum arr lo hi = sum arr lo mid + sum arr mid hi

  (** Frame lemma for [sum]: the value of [sum u lo hi] depends only
      on the cells [u[lo..hi-1]], so two arrays that agree on that
      range have the same sum over it *)
  lemma sum_frame :
    forall u w : array int, lo hi : int.
      0 <= lo <= hi ->
      hi <= u.length ->
      hi <= w.length ->
      (forall p : int. lo <= p < hi -> u[p] = w[p]) ->
      sum u lo hi = sum w lo hi

  (** Update lemma for [sum]: how [sum arr lo hi] changes when the
      cell [arr[pos]] is overwritten with [x], for some [pos] in
      [[lo..hi-1]] *)
  lemma sum_update :
    forall arr : array int, pos x lo hi : int.
      0 <= lo <= pos < hi <= arr.length ->
      sum (arr[pos<-x]) lo hi = sum arr lo hi + x - arr[pos]


end




(** {2 Algorithm 2: using a cumulative array}

   creation of the cumulative array is linear

   query is in constant time

   array update is linear

*)

module CumulativeArray

  use array.Array
  use import ArraySum
  use ExtraLemmas

  (** [is_cumulative_array_for c a] holds when [c] has exactly one
      more cell than [a] and every cell [c[i]] contains the prefix
      sum [sum a 0 i] *)
  predicate is_cumulative_array_for (c:array int) (a:array int) =
    c.length = a.length + 1 /\
    forall i. 0 <= i < c.length -> c[i] = sum a 0 i

  (** [create a] builds the cumulative array associated with [a]. *)
  let create (a:array int) : array int
    ensures { is_cumulative_array_for result a }
  = let l = a.length in
    let s = Array.make (l+1) 0 in
    (* cell 0 keeps its initial value 0; every later cell is the
       previous prefix sum plus one more element of [a] *)
    for i=1 to l do
      invariant { forall k. 0 <= k < i -> s[k] = sum a 0 k }
      s[i] <- s[i-1] + a[i-1]
    done;
    s

  (** [query c i j a] returns the sum of elements in [a] between
      index [i] inclusive and index [j] exclusive, in constant time.
      The array [a] is a ghost argument: only [c] is read. *)
  let query (c:array int) (i j:int) (ghost a:array int): int
    requires { is_cumulative_array_for c a }
    requires { 0 <= i <= j < c.length }
    ensures { result = sum a i j }
  = c[j] - c[i]


  (** [update c i v a] updates cell [a[i]] to value [v] and updates
      the cumulative array [c] accordingly. All other cells of [a]
      are left unchanged. *)
  let update (c:array int) (i:int) (v:int) (ghost a:array int) : unit
    requires { is_cumulative_array_for c a }
    requires { 0 <= i < a.length }
    writes  { c, a }
    ensures { is_cumulative_array_for c a }
    ensures { a[i] = v }
    ensures { forall k. 0 <= k < a.length /\ k <> i ->
              a[k] = (old a)[k] }
  = 'Init:
    (* by the precondition, [c[i+1] - c[i]] is the current value of
       [a[i]], hence [incr] is the amount by which that cell changes *)
    let incr = v - c[i+1] + c[i] in
    a[i] <- v;
    (* every prefix sum that includes cell [i] is shifted by [incr] *)
    for j=i+1 to c.length-1 do
      invariant { forall k. j <= k < c.length -> c[k] = sum a 0 k - incr }
      invariant { forall k. 0 <= k < j -> c[k] = sum a 0 k }
      c[j] <- c[j] + incr
    done

end






(** {2 Algorithm 3: using a cumulative tree}

  creation is linear

  query is logarithmic

  update is logarithmic

*)

module CumulativeTree

  use array.Array
  use import ArraySum
  use ExtraLemmas
  use import int.ComputerDivision

  (** information stored at each leaf and node of the tree: a range
      given by [low] (inclusive) and [high] (exclusive), and a value
      [isum] that [is_indexes_for] relates to the sum of the array
      over that range *)
  type indexes =
    { low : int;
      high : int;
      isum : int;
    }

  (** binary tree carrying an [indexes] record at every leaf and
      every node *)
  type tree = Leaf indexes | Node indexes tree tree

  (** [indexes t] is the record stored at the root of [t] *)
  function indexes (t:tree) : indexes =
    match t with
    | Leaf ind -> ind
    | Node ind _ _ -> ind
    end

  (** [is_indexes_for ind a i j] holds when [ind] describes the
      non-empty range [[i..j-1]] of [a] and caches its sum *)
  predicate is_indexes_for (ind:indexes) (a:array int) (i j:int) =
    ind.low = i /\ ind.high = j /\
    0 <= i < j <= a.length /\
    ind.isum = sum a i j

  (** [is_tree_for t a i j] holds when [t] is a cumulative tree for
      the range [[i..j-1]] of [a]: a leaf covers exactly one cell,
      and a node covers [[i..j-1]] with its left subtree covering
      [[i..m-1]] and its right subtree covering [[m..j-1]], where
      [m = div (i+j) 2] *)
  predicate is_tree_for (t:tree) (a:array int) (i j:int) =
    match t with
    | Leaf ind ->
        is_indexes_for ind a i j /\ j = i+1
    | Node ind l r ->
        is_indexes_for ind a i j /\
        i = l.indexes.low /\ j = r.indexes.high /\
        let m = l.indexes.high in
        m = r.indexes.low /\
        i < m < j /\ m = div (i+j) 2 /\
        is_tree_for l a i m /\
        is_tree_for r a m j
    end

  (** {3 creation of cumulative tree} *)

  (** [tree_of_array a i j] builds a cumulative tree for the
      non-empty range [[i..j-1]] of [a], splitting the range at
      [div (i+j) 2] until single cells are reached *)
  let rec tree_of_array (a:array int) (i j:int) : tree
    requires { 0 <= i < j <= a.length }
    variant { j - i }
    ensures { is_tree_for result a i j }
    = if i+1=j then begin
       Leaf { low = i; high = j; isum = a[i] }
       end
      else
        begin
        let m = div (i+j) 2 in
        assert { i < m < j };
        let l = tree_of_array a i m in
        let r = tree_of_array a m j in
        let s = l.indexes.isum + r.indexes.isum in
        assert { s = sum a i j };
        Node { low = i; high = j; isum = s} l r
        end


  (** [create a] builds a cumulative tree for the whole array [a],
      which must be non-empty *)
  let create (a:array int) : tree
    requires { a.length >= 1 }
    ensures { is_tree_for result a 0 a.length }
  = tree_of_array a 0 a.length


(** {3 query using cumulative tree} *)


  (** [query_aux t a i j] returns [sum a i j] for a non-empty range
      [[i..j-1]] included in the range covered by [t]. The array [a]
      is a ghost argument: only the sums cached in [t] are read. *)
  let rec query_aux (t:tree) (ghost a: array int)
      (i j:int) : int
    requires { is_tree_for t a t.indexes.low t.indexes.high }
    requires { 0 <= t.indexes.low <= i < j <= t.indexes.high <= a.length }
    variant { t }
    ensures { result = sum a i j }
  = match t with
    | Leaf ind ->
      ind.isum
    | Node ind l r ->
      let k1 = ind.low in
      let k3 = ind.high in
      (* the query is exactly the range of this node: use its cache *)
      if i=k1 && j=k3 then ind.isum else
      let m = l.indexes.high in
      (* otherwise descend to the left, to the right, or split the
         query at [m] when it overlaps both subtrees *)
      if j <= m then query_aux l a i j else
      if i >= m then query_aux r a i j else
      query_aux l a i m + query_aux r a m j
    end


  (** [query t a i j] returns [sum a i j] for any range of [a],
      including the empty one, given a tree [t] for the whole array *)
  let query (t:tree) (ghost a: array int) (i j:int) : int
    requires { 0 <= i <= j <= a.length }
    requires { is_tree_for t a 0 a.length }
    ensures { result = sum a i j }
  = if i=j then 0 else query_aux t a i j


  (** frame lemma for predicate [is_tree_for]: changing a cell of [a]
      outside the range [[i..j-1]] preserves [is_tree_for t a i j] *)
  lemma is_tree_for_frame : forall t:tree, a:array int, k v i j:int.
    0 <= k < a.length ->
    k < i \/ k >= j ->
    is_tree_for t a i j ->
    is_tree_for t a[k<-v] i j

(** {3 update cumulative tree} *)


  (** [update_aux t i a v] returns a pair [(t',delta)] where [t'] is
      a tree for the array [a[i<-v]] over the same range as [t], and
      [delta = v - a[i]] is the change of the value at index [i].
      The array [a] is a ghost argument and is not written here. *)
  let rec update_aux (t:tree) (i:int) (ghost a :array int) (v:int) : (tree,int)
    requires { is_tree_for t a t.indexes.low t.indexes.high }
    requires { t.indexes.low <= i < t.indexes.high }
    variant { t }
    returns { (t',delta) ->
        delta = v - a[i] /\
        t'.indexes.low = t.indexes.low /\
        t'.indexes.high = t.indexes.high /\
        is_tree_for t' a[i<-v] t'.indexes.low t'.indexes.high }
  = match t with
    | Leaf ind ->
        assert { i = ind.low };
        (Leaf { ind with isum = v }, v - ind.isum)
    | Node ind l r ->
      let m = l.indexes.high in
      (* only the subtree whose range contains [i] is rebuilt; the
         other one is reused as is *)
      if i < m then
        let l',delta = update_aux l i a v in
        assert { is_tree_for l' a[i<-v] t.indexes.low m };
        assert { is_tree_for r a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta } l' r, delta)
      else
        let r',delta = update_aux r i a v in
        assert { is_tree_for l a[i<-v] t.indexes.low m };
        assert { is_tree_for r' a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta} l r',delta)
    end

  (** [update t a i v] sets [a[i]] to [v] and returns a cumulative
      tree for the updated array. All other cells of [a] are left
      unchanged. *)
  let update (t:tree) (ghost a:array int) (i v:int) : tree
     requires { 0 <= i < a.length }
     requires { is_tree_for t a 0 a.length }
     writes { a }
     ensures { a[i] = v }
     ensures { forall k. 0 <= k < a.length /\ k <> i -> a[k] = (old a)[k] }
     ensures { is_tree_for result a 0 a.length }
  = let t,_ = update_aux t i a v in
    a[i] <- v;
    t


(** {2 complexity analysis}

  We would like to prove that [query] is really logarithmic. This is
  non-trivial because there are two recursive calls in some cases.

  So far, we are only able to prove that [update] is logarithmic

  We express the complexity by passing a ``credit'' as a ghost
  parameter. We pose the precondition that the credit is at least
  equal to the depth of the tree.

*)

  (** preliminaries: definition of the depth of a tree, and showing
      that it is indeed logarithmic in the number of its elements *)

  use import int.MinMax

  (** [depth t] is the number of nodes on a longest path from the
      root of [t] down to a leaf; a single leaf has depth 1 *)
  function depth (t:tree) : int =
    match t with
    | Leaf _ -> 1
    | Node _ l r -> 1 + max (depth l) (depth r)
    end

  lemma depth_min : forall t. depth t >= 1

  use import bv.Pow2int

  (** [depth_is_log t a k]: if [t] is a cumulative tree covering at
      most [pow2 k] cells of [a], then [depth t <= k+1] *)
  let rec lemma depth_is_log (t:tree) (a :array int) (k:int)
     requires { k >= 0 }
     requires { is_tree_for t a t.indexes.low t.indexes.high }
     requires { t.indexes.high - t.indexes.low <= pow2 k }
     variant { t }
     ensures { depth t <= k+1 }
  = match t with
    | Leaf _ -> ()
    | Node _ l r ->
       depth_is_log l a (k-1);
       depth_is_log r a (k-1)
    end


  (** [update_aux] function instrumented with a credit *)

  use import ref.Ref

  (** same code and contract as [update_aux], with an additional
      ghost counter [c] incremented once per call; the extra
      postcondition bounds the increase of [c] by [depth t] *)
  let rec update_aux_complexity (t:tree) (i:int) (ghost a :array int) (v:int) (ghost c:ref int): (tree,int)
     requires { is_tree_for t a t.indexes.low t.indexes.high }
     requires { t.indexes.low <= i < t.indexes.high }
     variant { t }
     ensures { !c - old !c <= depth t }
     returns { (t',delta) ->
        delta = v - a[i] /\
        t'.indexes.low = t.indexes.low /\
        t'.indexes.high = t.indexes.high /\
        is_tree_for t' a[i<-v] t'.indexes.low t'.indexes.high }
  = c := !c + 1;
    match t with
    | Leaf ind ->
      assert { i = ind.low };
      (Leaf { ind with isum = v }, v - ind.isum)
    | Node ind l r ->
      let m = l.indexes.high in
      if i < m then
        let l',delta = update_aux_complexity l i a v c in
        assert { is_tree_for l' a[i<-v] t.indexes.low m };
        assert { is_tree_for r a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta } l' r, delta)
      else
        let r',delta = update_aux_complexity r i a v c in
        assert { is_tree_for l a[i<-v] t.indexes.low m };
        assert { is_tree_for r' a[i<-v] m t.indexes.high };
        (Node {ind with isum = ind.isum + delta} l r',delta)
    end

  (** [query_aux] function instrumented with a credit *)

  (** same code and contract as [query_aux], with an additional
      ghost counter [c] incremented once per call; the extra
      postcondition bounds the increase of [c] by 1 when the query
      is exactly the range of [t], by [2 * depth t] when it shares
      one endpoint with that range, and by [4 * depth t] otherwise *)
  let rec query_aux_complexity (t:tree) (ghost a: array int)
      (i j:int) (ghost c:ref int) : int
    requires { is_tree_for t a t.indexes.low t.indexes.high }
    requires { 0 <= t.indexes.low <= i < j <= t.indexes.high <= a.length }
    variant { t }
    ensures { !c - old !c <=
         if i = t.indexes.low /\ j = t.indexes.high then 1 else
         if i = t.indexes.low \/ j = t.indexes.high then 2 * depth t else
          4 * depth t }
    ensures { result = sum a i j }
  = c := !c + 1;
    match t with
    | Leaf ind ->
      ind.isum
    | Node ind l r ->
      let k1 = ind.low in
      let k3 = ind.high in
      if i=k1 && j=k3 then ind.isum else
      let m = l.indexes.high in
      if j <= m then query_aux_complexity l a i j c else
      if i >= m then query_aux_complexity r a i j c else
      query_aux_complexity l a i m c + query_aux_complexity r a m j c
    end


end