(* lineardecision.mlw *)

module LinearEquationsCoeffs

(* Abstract interface for the coefficients of the linear equations.
   [a] is the carrier (an ordered unitary commutative ring), [t] is the
   computable representation of coefficients, and [interp] gives the value
   in [a] of a coefficient under an environment of type [cvars].
   Coefficient operations may raise [Unknown] when they cannot compute. *)

type a
function (+) a a : a
function ( *) a a : a
function (-_) a : a
function azero: a
function aone: a
predicate ale a a

clone algebra.OrderedUnitaryCommutativeRing as A with type t = a, function (+) = (+), function ( *) = ( *), function (-_) = (-_), constant zero = azero, constant one=aone, predicate (<=) = ale

(* Binary subtraction, defined as addition of the opposite. *)
function (-) a a : a

axiom sub_def: forall x x'. x - x' = x + (- x')

(* Coefficients, valuations of the unknowns, coefficient environments. *)
type t
type vars = int -> a
type cvars
exception Unknown

function interp t cvars : a

val constant czero : t
val constant cone : t

axiom zero_def: forall z. interp czero z = azero
axiom one_def: forall z. interp cone z = aone

val add (a b: t) : t
  ensures { forall v: cvars. interp result v = interp a v + interp b v }
  raises  { Unknown -> true }

val mul (a b: t) : t
  ensures { forall v: cvars. interp result v = interp a v * interp b v }
  raises  { Unknown -> true }

val opp (a:t) : t
  ensures { forall v: cvars. interp result v = - (interp a v) }

(* Sound equality test: [true] implies equal interpretations. *)
val predicate eq (a b:t)
  ensures { result -> forall y:cvars. interp a y = interp b y }

val inv (a:t) : t
  requires { not (eq a czero) }
  (* ensures { forall v: cvars. interp result v * interp a v = aone }
     no proof needed, but had better be true *)
  ensures { not (eq result czero) }
  raises { Unknown -> true }

end

module LinearEquationsDecision

use import int.Int

(* Coefficients are abstract here; a concrete domain is chosen by cloning. *)
type coeff

clone LinearEquationsCoeffs as C with type t = coeff

(* Valuation of the integer-indexed unknowns. *)
type vars = C.vars

(* Linear expressions: [Term c i] is c times unknown i, [Cst c] is a
   constant, [Add] is a sum. *)
type expr =
  | Term coeff int
  | Add expr expr
  | Cst coeff

(* Every unknown index occurring in [e] is non-negative. *)
let rec predicate valid_expr (e:expr)
  variant { e }
= match e with
  | Cst _ -> true
  | Term _ idx -> 0 <= idx
  | Add l r -> valid_expr l && valid_expr r
  end

(* Every unknown index occurring in [e] lies in the range [0, b]. *)
let rec predicate expr_bound (e:expr) (b:int)
  variant { e }
= match e with
  | Cst _ -> true
  | Term _ idx -> 0 <= idx <= b
  | Add l r -> expr_bound l b && expr_bound r b
  end

(* Value of a linear expression, given values [y] for the unknowns and an
   environment [z] for interpreting coefficients. *)
function interp (e:expr) (y:vars) (z:C.cvars) : C.a
= match e with
  | Cst c -> C.interp c z
  | Term c v -> C.( *) (C.interp c z) (y v)
  | Add l r -> C.(+) (interp l y z) (interp r y z)
  end

use import bool.Bool
use import list.List

(* An equation (lhs, rhs), read as lhs = rhs. *)
type equality = (expr, expr)
(* A list of hypotheses. *)
type context = list equality

(* Both sides have non-negative unknown indices. *)
let predicate valid_eq (eq:equality)
= match eq with (lhs, rhs) -> valid_expr lhs && valid_expr rhs end

(* Both sides have unknown indices in [0, b]. *)
let predicate eq_bound (eq:equality) (b:int)
= match eq with (lhs, rhs) -> expr_bound lhs b && expr_bound rhs b end

(* Every hypothesis of the context is valid. *)
let rec predicate valid_ctx (ctx:context)
= match ctx with
  | Nil -> true
  | Cons eq rest -> valid_eq eq && valid_ctx rest
  end

(* Every hypothesis of the context has unknown indices in [0, b]. *)
let rec predicate ctx_bound (ctx:context) (b:int)
= match ctx with
  | Nil -> true
  | Cons eq rest -> eq_bound eq b && ctx_bound rest b
  end

(* A bound on unknown indices may always be weakened to a larger one. *)
let rec lemma expr_bound_w (e:expr) (b1 b2:int)
  requires { b1 <= b2 }
  requires { expr_bound e b1 }
  ensures  { expr_bound e b2 }
  variant  { e }
= match e with
  | Cst _ | Term _ _ -> ()
  | Add l r -> expr_bound_w l b1 b2; expr_bound_w r b1 b2
  end

lemma eq_bound_w: forall e:equality, b1 b2:int. eq_bound e b1 -> b1 <= b2 -> eq_bound e b2

let rec lemma ctx_bound_w (l:context) (b1 b2:int)
  requires { ctx_bound l b1 }
  requires { b1 <= b2 }
  ensures  { ctx_bound l b2 }
  variant  { l }
= match l with
  | Nil -> ()
  | Cons _ rest -> ctx_bound_w rest b1 b2
  end

(* Truth of an equation under the given valuations. *)
function interp_eq (g:equality) (y:vars) (z:C.cvars) : bool
= match g with (lhs, rhs) -> interp lhs y z = interp rhs y z end

(* [interp_ctx l g y z]: if every hypothesis of [l] holds, then [g] holds. *)
function interp_ctx (l: context) (g: equality) (y: vars) (z:C.cvars) : bool
= match l with
  | Nil -> interp_eq g y z
  | Cons hyp rest -> (interp_eq hyp y z) -> (interp_ctx rest g y z)
  end

use import array.Array
use import matrix.Matrix

(* Matrix-vector product [m * v]: one entry per row of [m]. *)
let apply_r (m: matrix coeff) (v: array coeff) : array coeff
  requires { v.length = m.columns }
  ensures  { result.length = m.rows }
  raises   { C.Unknown -> true }
= let res = Array.make m.rows C.czero in
  for row = 0 to m.rows - 1 do
    for col = 0 to m.columns - 1 do
      res[row] <- C.add res[row] (C.mul (get m row col) v[col])
    done
  done;
  res

(* Vector-matrix product [v * m]: one entry per column of [m]. *)
let apply_l (v: array coeff) (m: matrix coeff) : array coeff
  requires { v.length = m.rows }
  ensures  { result.length = m.columns }
  raises   { C.Unknown -> true }
= let res = Array.make m.columns C.czero in
  for col = 0 to m.columns - 1 do
    for row = 0 to m.rows - 1 do
      res[col] <- C.add res[col] (C.mul (get m row col) v[row])
    done
  done;
  res

use import ref.Ref

(* Scalar product of two coefficient vectors of equal length. *)
let sprod (a b: array coeff) : coeff
  requires { a.length = b.length }
  raises   { C.Unknown -> true }
= let acc = ref C.czero in
  for k = 0 to a.length - 1 do
    acc := C.add !acc (C.mul a[k] b[k])
  done;
  !acc

(* Copy of [m] extended with [v] as an extra last column. *)
let m_append (m: matrix coeff) (v:array coeff) : matrix coeff
  requires { m.rows = v.length }
  ensures  { result.rows = m.rows }
  ensures  { result.columns = m.columns + 1 }
  ensures  { forall i j. 0 <= i < m.rows -> 0 <= j < m.columns ->
             result.elts i j = m.elts i j }
  ensures  { forall i. 0 <= i < m.rows -> result.elts i m.columns = v[i] }
= let res = Matrix.make m.rows (m.columns + 1) C.czero in
  for row = 0 to m.rows - 1 do
    (* rows already processed hold a copy of [m] followed by [v] *)
    invariant { forall k l. 0 <= k < row -> 0 <= l < m.columns ->
                res.elts k l = m.elts k l }
    invariant { forall k. 0 <= k < row -> res.elts k m.columns = v[k] }
    for col = 0 to m.columns - 1 do
      invariant { forall k l. 0 <= k < row -> 0 <= l < m.columns ->
                  res.elts k l = m.elts k l }
      invariant { forall k. 0 <= k < row -> res.elts k m.columns = v[k] }
      invariant { forall l. 0 <= l < col -> res.elts row l = m.elts row l }
      set res row col (get m row col)
    done;
    set res row m.columns v[row]
  done;
  res

(* Copy of [v] with [c] appended at the end. *)
let v_append (v: array coeff) (c: coeff) : array coeff
  ensures { length result = length v + 1 }
  ensures { forall k. 0 <= k < v.length -> result[k] = v[k] }
  ensures { result[v.length] = c }
= (* every cell starts as [c]; only the first [length v] are overwritten *)
  let res = Array.make (v.length + 1) c in
  for k = 0 to v.length - 1 do
    invariant { forall l. 0 <= l < k -> res[l] = v[l] }
    invariant { res[v.length] = c }
    res[k] <- v[k]
  done;
  res

(* Element-wise comparison of two coefficient arrays with [C.eq];
   [true] guarantees equal lengths and pairwise [C.eq] elements. *)
let predicate (==) (a b: array coeff)
  ensures { result = true -> length a = length b /\
            forall i. 0 <= i < length a -> C.eq a[i] b[i] }
=
  if length a <> length b then false
  else begin
    let same = ref true in
    for k = 0 to length a - 1 do
      invariant { !same = true -> forall j. 0 <= j < k -> C.eq a[j] b[j] }
      if not (C.eq a[k] b[k]) then same := false
    done;
    !same
  end

use import int.MinMax
use import list.Length

(* Largest unknown index occurring in [e], or 0 if there is none. *)
let rec function max_var (e:expr) : int
  variant { e }
  requires { valid_expr e }
  ensures { 0 <= result }
  ensures { expr_bound e result }
= match e with
  | Cst _ -> 0
  | Term _ idx -> idx
  | Add l r -> max (max_var l) (max_var r)
  end

(* Largest unknown index occurring on either side of the equation. *)
let function max_var_e (e:equality) : int
  requires { valid_eq e }
  ensures { 0 <= result }
  ensures { eq_bound e result }
= match e with (lhs, rhs) -> max (max_var lhs) (max_var rhs) end

(* Largest unknown index occurring anywhere in the context (0 if empty). *)
let rec function max_var_ctx (l:context) : int
  variant { l }
  requires { valid_ctx l }
  ensures { 0 <= result }
  ensures { ctx_bound l result }
= match l with
  | Nil -> 0
  | Cons eq rest -> max (max_var_e eq) (max_var_ctx rest)
  end

(* Expression whose value is the opposite of [e]'s; validity and
   unknown-index bounds are preserved. *)
let rec opp_expr (e:expr) : expr
  ensures { forall y z. interp result y z = C.(-_) (interp e y z) }
  ensures { valid_expr e -> valid_expr result }
  ensures { forall b. expr_bound e b -> expr_bound result b }
  variant { e }
= match e with
  | Cst c -> Cst (C.opp c)
  | Term c j ->
    let neg_c = C.opp c in
    let res = Term neg_c j in
    (* -(c * x) = (-c) * x *)
    assert { forall y z. interp res y z = C.( *) (C.interp neg_c z) (y j)
             = C.( *) (C.(-_) (C.interp c z)) (y j)
             = C.(-_) (C.( *) (C.interp c z) (y j))
             = C.(-_) (interp e y z) };
    res
  | Add e1 e2 ->
    let n1 = opp_expr e1 in
    let n2 = opp_expr e2 in
    (* -(p + q) = (-p) + (-q) *)
    assert { forall a1 a2. C.(+) (C.(-_) a1) (C.(-_) a2) = C.(-_) (C.(+) a1 a2) };
    assert { forall y z. interp (Add n1 n2) y z = C.(-_) (interp e y z) by
             interp (Add n1 n2) y z = C.(+) (interp n1 y z) (interp n2 y z)
             = C.(+) (C.(-_) (interp e1 y z)) (C.(-_) (interp e2 y z))
             = C.(-_) (C.(+) (interp e1 y z) (interp e2 y z))
             = C.(-_) (interp e y z) };
    Add n1 n2
  end

(* An atom is any expression that is not a sum. *)
predicate atom (e:expr)
= match e with
  | Term _ _ | Cst _ -> true
  | Add _ _ -> false
  end

(* Flatten [ex] onto the accumulators: terms are appended to [acc_e] and
   constants are added into [acc_c]. The result (rex, rc) satisfies
   rex + rc = ex + (acc_e + acc_c); unknown-index bounds are preserved.
   TODO put this back in norm_eq *)
let rec norm_eq_aux (ex acc_e:expr) (acc_c:coeff) : (expr, coeff)
  returns { (rex, rc) -> forall y z.
              C.(+) (interp rex y z) (interp (Cst rc) y z)
            = C.(+) (interp ex y z)
                    (C.(+) (interp acc_e y z) (interp (Cst acc_c) y z)) }
  returns { (rex, _) -> forall b:int. expr_bound ex b /\ expr_bound acc_e b
                        -> expr_bound rex b }
  raises  { C.Unknown -> true }
  variant { ex }
= match ex with
  | Term _ _ -> (Add acc_e ex, acc_c)
  | Cst c -> acc_e, (C.add c acc_c)
  | Add e1 e2 ->
    let mid_e, mid_c = norm_eq_aux e1 acc_e acc_c in
    norm_eq_aux e2 mid_e mid_c
  end

(* Normalize the equation [e1 = e2] into the form [ex = Cst c] by moving
   everything to the left (e1 + (-e2) = 0) and collecting constants.
   Raises [C.Unknown] when a coefficient operation gives up. *)
let norm_eq (e:equality) : (expr, coeff)
  returns { (ex, c) -> forall y z.
            interp_eq e y z <-> interp_eq (ex, Cst c) y z }
  returns { (ex, _) -> forall b:int. eq_bound e b -> expr_bound ex b }
  raises  { C.Unknown -> true }
= match e with
  | (e1, e2) ->
    let s = Add e1 (opp_expr e2) in
    assert { forall b. eq_bound e b -> expr_bound s b };
    match norm_eq_aux s (Cst C.czero) C.czero with
      (ex, c) ->
        (* ex + c = 0, hence ex = -c *)
        let neg_c = C.opp c in
        assert { forall a1 a2. C.(+) a1 a2 = C.azero -> a1 = C.(-_) a2 };
        assert { forall y z. interp_eq (e1,e2) y z -> interp_eq (ex, Cst neg_c) y z
                 by interp_eq (s, Cst C.czero) y z so interp s y z = C.azero
                 so C.(+) (interp ex y z) (interp (Cst c) y z) = C.azero
                 so interp ex y z = C.(-_) (interp (Cst c) y z)
                    = interp (Cst neg_c) y z };
        ex, neg_c
    end
  end

(* If g1 entails g2 pointwise, a context that yields g1 also yields g2. *)
let rec lemma interp_ctx_impl (ctx: context) (g1 g2:equality)
  requires { forall y z. interp_eq g1 y z -> interp_eq g2 y z }
  ensures  { forall y z. interp_ctx ctx g1 y z -> interp_ctx ctx g2 y z }
  variant  { ctx }
= match ctx with
  | Nil -> ()
  | Cons _ rest -> interp_ctx_impl rest g1 g2
  end

(* A goal that holds outright holds under any context. *)
let rec lemma interp_ctx_valid (ctx:context) (g:equality)
  ensures { forall y z. interp_eq g y z -> interp_ctx ctx g y z }
  variant  { ctx }
= match ctx with
  | Nil -> ()
  | Cons _ rest -> interp_ctx_valid rest g
  end

use import list.Append

(* Weakening on the right: extra hypotheses appended after [ctx]. *)
let rec lemma interp_ctx_wr (ctx l:context) (g:equality)
  ensures { forall y z. interp_ctx ctx g y z -> interp_ctx (ctx ++ l) g y z }
  variant { ctx }
= match ctx with
  | Nil -> ()
  | Cons _ rest -> interp_ctx_wr rest l g
  end

(* Weakening on the left: extra hypotheses prepended before [ctx]. *)
let rec lemma interp_ctx_wl (ctx l: context) (g:equality)
  ensures { forall y z. interp_ctx ctx g y z -> interp_ctx (l ++ ctx) g y z }
  variant { l }
= match l with
  | Nil -> ()
  | Cons _ rest -> interp_ctx_wl ctx rest g
  end

(* Scale [e] by the coefficient [c]; a coefficient [C.eq] to zero yields
   the constant zero directly. *)
let rec mul_expr (e:expr) (c:coeff) : expr
  ensures { forall y z. interp result y z
            = C.( *) (C.interp c z) (interp e y z) }
  ensures { valid_expr e -> valid_expr result }
  variant { e }
  raises  { C.Unknown -> true }
= if C.eq c C.czero then Cst C.czero
  else match e with
  | Cst k -> Cst (C.mul c k)
  | Term k v -> Term (C.mul c k) v
  | Add l r -> Add (mul_expr l c) (mul_expr r c)
  end

(* Expression whose value is e1 + e2. The atoms of [e2] are merged into
   [e1] one at a time: terms on the same unknown are combined, constants
   are added, and parts detected to be zero are dropped. *)
let rec add_expr (e1 e2: expr) : expr
  ensures { forall y z. interp result y z
                     = C.(+) (interp e1 y z) (interp e2 y z) }
  variant { e2 }
  raises  { C.Unknown -> true }
=
  (* [Term c i], or the constant zero when [c] is zero *)
  let term_or_cst c i
    ensures { forall y z. interp result y z = interp (Term c i) y z }
  = if C.eq C.czero c then Cst C.czero else Term c i in
  (* [add_atom e a] returns (r, merged) with r = e + a; [merged] is
     [False] when [a] was only appended as a new summand *)
  let rec add_atom (e a:expr) : (expr, bool)
    requires { atom a }
    returns { r,b -> forall y z. interp r y z
                     = C.(+) (interp e y z) (interp a y z) }
    variant { e }
    raises  { C.Unknown -> true }
  = match (e,a) with
    | Term ce ie, Term ca ia ->
      if ie = ia then (term_or_cst (C.add ce ca) ie, True)
      else if C.eq ce C.czero then (term_or_cst ca ia, True)
      else if C.eq ca C.czero then (e,True)
      else (Add e a, False)
    | Cst ce, Cst ca -> Cst (C.add ce ca), True
    | Cst ce, Term ca _ ->
      if C.eq ca C.czero then (e, True)
      else if C.eq ce C.czero then (a, True)
      else (Add e a, False)
    | Term ce _, Cst ca ->
      if C.eq ce C.czero then (a, True)
      else if C.eq ca C.czero then (e, True)
      else (Add e a, False)
    | Add left right, _ ->
      let s, merged = add_atom left a in
      if merged
      then
        match s with
          | Cst c ->
            if C.eq c C.czero
            then begin
              (* left + a vanished: only [right] remains *)
              assert { forall y z. C.(+) (interp left y z) (interp a y z) = C.azero };
              right, True end
            else Add s right, True
          | _ -> Add s right, True
        end
      else
        let s, merged = add_atom right a in
        match s with
          | Cst c ->
            if C.eq c C.czero
            then begin
              (* right + a vanished: only [left] remains *)
              assert { forall y z. C.(+) (interp right y z) (interp a y z) = C.azero };
              left, True end
            else Add left s, merged
          | _ -> Add left s, merged
        end
    | _, Add _ _ -> absurd
    end
  in
  match e2 with
    | Add first rest -> add_expr (add_expr e1 first) rest
    | _ -> let r,_= add_atom e1 e2 in r
  end

(* Multiply both sides of the equation by [c]. *)
let mul_eq (eq:equality) (c:coeff)
  ensures { forall y z. interp_eq eq y z -> interp_eq result y z }
  raises  { C.Unknown -> true }
= let (lhs, rhs) = eq in
  (mul_expr lhs c, mul_expr rhs c)

(* Add two equations side by side: from a1 = b1 and a2 = b2 derive
   a1 + a2 = b1 + b2, both pointwise and under any context. *)
let add_eq (eq1 eq2:equality)
  ensures { forall y z. interp_eq eq1 y z -> interp_eq eq2 y z
            -> interp_eq result y z }
  ensures { forall y z ctx. interp_ctx ctx eq1 y z -> interp_ctx ctx eq2 y z
            -> interp_ctx ctx result y z }
  raises  { C.Unknown -> true }
= match eq1, eq2 with ((a1,b1), (a2,b2)) ->
  let lhs = add_expr a1 a2 in
  let rhs = add_expr b1 b2 in
  let res = (lhs, rhs) in
  (* lift the pointwise property through an arbitrary context *)
  let rec lemma aux (l:context)
    ensures { forall y z. interp_ctx l eq1 y z -> interp_ctx l eq2 y z
              -> interp_ctx l res y z }
    variant { l }
  = match l with
    | Nil -> ()
    | Cons _ rest -> aux rest
    end in
  res
  end

(* [zero_expr e] returns [true] only if [e] evaluates to zero for all
   valuations. [e] is first simplified with [add_expr]; then every
   coefficient of the result must be [C.eq] to zero. [false] means the
   check failed, not that [e] is non-zero.
   The function itself does not recurse (only [all_zero] does), so it is
   declared with a plain [let] and needs no variant. *)
let zero_expr (e:expr) : bool
  ensures { result -> forall y z. interp e y z = C.azero }
  raises  { C.Unknown -> true }
=
  let rec all_zero (e:expr) : bool
    ensures { result -> forall y z. interp e y z = C.azero }
    variant { e }
  = match e with
    | Cst c -> C.eq c C.czero
    | Term c _ -> C.eq c C.czero
    | Add e1 e2 -> all_zero e1 && all_zero e2
    end
  in
  let e' = add_expr (Cst C.czero) e in (* simplifies expr *)
  all_zero e'

(* Expression r with r + e2 = e1, computed as e1 + (-1) * e2. *)
let sub_expr (e1 e2:expr)
  ensures { forall y z. C.(+) (interp result y z) (interp e2 y z)
                        = interp e1 y z }
  raises  { C.Unknown -> true }
= let res = add_expr e1 (mul_expr e2 (C.opp C.cone)) in
  assert { forall y z.
           let v1 = interp e1 y z in
           let v2 = interp e2 y z in
           let vr = interp res y z in
           C.(+) vr v2 = v1
           by C.( *) v2 (C.(-_) C.aone) = C.(-_) v2
           so C.(+) vr v2
           = C.(+) (C.(+) v1 (C.( *) v2 (C.(-_) C.aone))) v2
           = C.(+) (C.(+) v1 (C.(-_) v2)) v2 = v1 };
  res
use import debug.Debug

(* [same_eq eq1 eq2] tries to show that eq1 entails eq2. Both are
   normalized to [ex = c] form; the check succeeds when the difference of
   the left sides simplifies to zero and the constants are [C.eq].
   On failure the simplified difference and both constants are printed
   (debug output) and [false] is returned.
   The function is not recursive, so it is declared with a plain [let]. *)
let same_eq (eq1 eq2: equality) : bool
  ensures { result -> forall y z. interp_eq eq1 y z -> interp_eq eq2 y z }
  raises  { C.Unknown -> true }
= let (e1,c1) = norm_eq eq1 in
  let (e2,c2) = norm_eq eq2 in
  let e = sub_expr e1 e2 in
  if zero_expr e && C.eq c1 c2 then true
  else (print (add_expr (Cst C.czero) e); print c1; print c2; false)

use import option.Option

(* Normalize every hypothesis with [norm_eq] into [ex = Cst c] form. *)
let rec norm_context (l:context) : context
  ensures { forall g y z. interp_ctx result g y z -> interp_ctx l g y z }
  raises  { C.Unknown -> true }
  variant { l }
= match l with
  | Nil -> Nil
  | Cons hd tl ->
    let (ex, c) = norm_eq hd in
    Cons (ex, Cst c) (norm_context tl)
  end

(* Debug helper: print each hypothesis of [ctx] with its coefficient from
   [v], skipping those whose coefficient is [C.eq] to zero; stops as soon
   as either list is exhausted. *)
let rec print_lc ctx v : unit
  variant { ctx }
= match ctx, v with
  | Cons hyp rest, Cons coef coefs ->
    if not (C.eq C.czero coef) then begin print hyp; print coef end;
    print_lc rest coefs
  | _ -> ()
  end

(* [check_combination ctx g v] sums the hypotheses of [ctx] weighted by the
   coefficients [v] and checks with [same_eq] that the sum entails [g].
   [true] means [g] follows from [ctx]; [false] means the check failed
   (the hypotheses and coefficients are then printed for debugging). *)
let check_combination (ctx:context) (g:equality) (v:list coeff) : bool
  ensures  { result = true -> forall y z. interp_ctx ctx g y z}
  raises  { C.Unknown -> true }
=
  (*let ctx = norm_context ctx in
  let (g,c) = norm_eq g in*)
  (* normalize before for fewer Unknown exceptions in computations ? *)
  (* [aux l acc s v]: [acc] is the already-consumed prefix of [ctx], [l]
     the remaining suffix, and [s] the partial sum, entailed by [acc] *)
  let rec aux (l:context) (ghost acc: context) (s:equality) (v:list coeff) : option equality
    requires { forall y z. interp_ctx acc s y z }
    requires { ctx = acc ++ l }
    returns  { Some r -> forall y z. interp_ctx ctx r y z | None -> true }
    raises  { C.Unknown -> true }
    variant { l }
  = match (l, v) with
    | Nil, Nil -> Some s
    | Cons eq te, Cons c tc ->
      let ghost nacc = acc ++ (Cons eq Nil) in
      if C.eq c C.czero then aux te nacc s tc
      else begin
        let ns = (add_eq s (mul_eq eq c)) in
        (* [s] still follows after appending [eq] to the hypotheses [acc] *)
        interp_ctx_wr acc (Cons eq Nil) s;
        (* [eq] follows from acc ++ [eq], where it is the last hypothesis *)
        interp_ctx_wl (Cons eq Nil) acc eq;
        assert { forall y z. interp_ctx nacc ns y z
                 by interp_ctx nacc s y z /\ interp_ctx nacc eq y z };
        aux te nacc ns tc end
    | _ -> None
    end
  in
  match aux ctx Nil (Cst C.czero, Cst C.czero) v with
  | Some sum -> if same_eq sum g then true else (print_lc ctx v; false)
  | None -> false
  end

(* Transpose of [m]. *)
let transpose (m:matrix coeff) : matrix coeff
  ensures { result.rows = m.columns /\ result.columns = m.rows }
=
  let res = Matrix.make m.columns m.rows C.czero in
  for row = 0 to m.rows - 1 do
    for col = 0 to m.columns - 1 do
      set res col row (get m row col)
    done
  done;
  res

(* Exchange rows [i1] and [i2] of [m] in place. *)
let swap_rows (m:matrix coeff) (i1 i2: int) : unit
  requires { 0 <= i1 < m.rows /\ 0 <= i2 < m.rows }
= for col = 0 to m.columns - 1 do
    let tmp = get m i1 col in
    set m i1 col (get m i2 col);
    set m i2 col tmp
  done

(* Multiply row [i] of [m] in place by the non-zero coefficient [c];
   nothing is done when [c] is [C.eq] to one. *)
let mul_row (m:matrix coeff) (i: int) (c: coeff) : unit
  requires { 0 <= i < m.rows }
  requires { not (C.eq c C.czero) }
  raises  { C.Unknown -> true }
= if not (C.eq c C.cone) then
    for col = 0 to m.columns - 1 do
      set m i col (C.mul c (get m i col))
    done

(* Row operation dst <- dst + c * src, in place; skipped when [c] is
   [C.eq] to zero. *)
let addmul_row (m:matrix coeff) (src dst: int) (c: coeff) : unit
  requires { 0 <= src < m.rows /\ 0 <= dst < m.rows }
  raises   { C.Unknown -> true }
= if not (C.eq c C.czero) then
    for col = 0 to m.columns - 1 do
      set m dst col (C.add (get m dst col) (C.mul c (get m src col)))
    done

use import ref.Refint

(* Gauss-Jordan elimination on [a], read as the augmented matrix (A|B) of
   the system AX = B. For each column of A (every column but the last) a
   row with a non-zero entry is chosen, scaled by the inverse of that
   entry, moved to the next pivot row, and the column is cleared from all
   other rows. Returns [None] if no pivot was found; otherwise an array
   of length [a.columns] holding, at each pivot column, the matching
   entry of the last column (other entries are zero). *)
let gauss_jordan (a: matrix coeff) : option (array coeff)
  (*AX=B, a=(A|B), result=X*)
  returns { Some r -> Array.length r = a.columns | None -> true }
  requires { 1 <= a.rows /\ 1 <= a.columns }
  raises { C.Unknown -> true }
=
  let n = a.rows in
  let m = a.columns in
  (* first row index >= i whose entry in column j is not [C.eq] to zero,
     or [n] if there is none *)
  let rec find_nonz (i j:int)
    requires { 0 <= i <= n }
    requires { 0 <= j < m }
    variant { n-i }
    ensures { i <= result <= n }
    ensures { result < n -> not (C.eq (a.elts result j) C.czero) }
  = if i >= n then n
    else if C.eq (get a i j) C.czero then find_nonz (i+1) j
    else i
  in
  (* pivots[k] is the column of the pivot placed in row k, for k <= !rank *)
  let pivots = Array.make n 0 in
  let rank = ref (-1) in
  for j = 0 to m-2 do
    invariant { -1 <= !rank < n }
    invariant { forall i. 0 <= i <= !rank -> 0 <= pivots[i] }
    invariant { forall i1 i2: int. 0 <= i1 < i2 <= !rank -> pivots[i1] < pivots[i2] }
    invariant { !rank >= 0 -> pivots[!rank] < j }
    label Start in
    let piv = find_nonz (!rank+1) j in
    if piv < n
    then begin
      incr rank;
      pivots[!rank] <- j;
      (* normalize the pivot row, then move it to position !rank *)
      mul_row a piv (C.inv(get a piv j));
      if piv <> !rank then swap_rows a piv !rank;
      (* clear column j in every other row *)
      for i = 0 to n-1 do
        if i <> !rank
        then addmul_row a !rank i (C.opp(get a i j))
      done;
    end
  done;
  if !rank < 0 then None (* matrix is all zeroes *)
  else begin
    (* length m (not m-1) so that the result has a.columns entries *)
    let sol = Array.make m C.czero in
    for i = 0 to !rank do
      sol[pivots[i]] <- get a i (m-1)
    done;
    Some sol (*pivots[!r] < m-1*)  (*pivot on last column, no solution*)
  end
use import array.ToList
(* Raised by [linear_decision] when a normalized side still contains a
   constant that is not [C.eq] to zero. *)
exception Absurd

(* [linear_decision l g] tries to prove that the hypotheses [l] entail [g].
   Each hypothesis is encoded as a row of the system a.x = b and the goal
   as v.x = d. Gauss-Jordan elimination on the transposed, augmented
   system searches for coefficients combining the hypotheses into the
   goal, and [check_combination] re-checks the candidate. [false] means
   no proof was found. *)
let linear_decision (l: context) (g: equality) : bool
  requires { valid_ctx l }
  requires { valid_eq g }
  ensures { forall y z. result -> interp_ctx l g y z }
  raises  { C.Unknown -> true | Absurd -> true }
=
  let maxv = max (max_var_e g) (max_var_ctx l) in
  let nhyps = length l in
  (* hypothesis i is row i of [a] together with entry i of [b] *)
  let a = Matrix.make nhyps (maxv+1) C.czero in
  let b = Array.make nhyps C.czero in            (* ax = b *)
  let v = Array.make (maxv+1) C.czero in          (* goal *)
  let rec fill_expr (ex: expr) (i:int): unit
    variant { ex }
    raises  { C.Unknown -> true }
    requires { 0 <= i < length l }
    requires { expr_bound ex maxv }
    raises  { Absurd -> true }
  = match ex with
    | Cst c -> if C.eq c C.czero then () else raise Absurd
    | Term c j -> set a i j (C.add (get a i j) c)
    | Add e1 e2 -> fill_expr e1 i; fill_expr e2 i
    end in
  let rec fill_ctx (ctx:context) (i:int) : unit
    requires { ctx_bound ctx maxv }
    variant { length l - i }
    requires { length l - i = length ctx }
    requires { 0 <= i <= length l }
    raises  { C.Unknown -> true }
    raises  { Absurd -> true }
  = match ctx with
    | Nil -> ()
    | Cons e rest ->
      assert { i < length l };
      let ex, c = norm_eq e in
      if (not (C.eq c C.czero)) then b[i] <- C.add b[i] c;
      fill_expr ex i;
      fill_ctx rest (i+1)
    end in
  let rec fill_goal (ex:expr) : unit
    requires { expr_bound ex maxv }
    variant { ex }
    raises { C.Unknown -> true }
    raises  { Absurd -> true }
  = match ex with
    | Cst c -> if C.eq c C.czero then () else raise Absurd
    | Term c j -> v[j] <- C.add v[j] c
    | Add e1 e2 -> fill_goal e1; fill_goal e2
    end in
  fill_ctx l 0;
  let (ex, d) = norm_eq g in
  fill_goal ex;
  let ab = m_append a b in
  let cd = v_append v d in
  let abt = transpose ab in
  match gauss_jordan (m_append abt cd) with
    | Some r ->
      check_combination l g (to_list r 0 nhyps)
    | None -> false
  end

(* Surface expressions before linearization: sums, differences, products
   by a constant product on either side, unknowns and coefficients. *)
type expr' =
  | Sum expr' expr'
  | ProdL expr' cprod
  | ProdR cprod expr'
  | Diff expr' expr'
  | Var int
  | Coeff coeff

(* Products of coefficients. *)
with cprod =
  | C coeff
  | Times cprod cprod

(* Value of a constant product; [y] is only passed through. *)
function interp_c (e:cprod) (y:vars) (z:C.cvars) : C.a
= match e with
  | Times p q -> C.( *) (interp_c p y z) (interp_c q y z)
  | C c -> C.interp c z
  end

(* Value of a surface expression under valuation [y] and coefficient
   environment [z]. *)
function interp' (e:expr') (y:vars) (z:C.cvars) : C.a
= match e with
  | Var n -> y n
  | Coeff c -> C.interp c z
  | Sum p q -> C.(+) (interp' p y z) (interp' q y z)
  | Diff p q -> C.(-) (interp' p y z) (interp' q y z)
  | ProdL p c -> C.( *) (interp' p y z) (interp_c c y z)
  | ProdR c p -> C.( *) (interp_c c y z) (interp' p y z)
  end

(*exception NonLinear*)

(* Surface equations and contexts, with their interpretation. *)
type equality' = (expr', expr')
type context' = list equality'

function interp_eq' (g:equality') (y:vars) (z:C.cvars) : bool
= match g with (lhs, rhs) -> interp' lhs y z = interp' rhs y z end

(* If every hypothesis of [l] holds, then [g] holds. *)
function interp_ctx' (l: context') (g: equality') (y: vars) (z:C.cvars) : bool
= match l with
  | Nil -> interp_eq' g y z
  | Cons hyp rest -> (interp_eq' hyp y z) -> (interp_ctx' rest g y z)
  end

(* Every unknown index occurring in the surface expression is
   non-negative. *)
let rec predicate valid_expr' (e:expr')
  variant { e }
= match e with
  | Coeff _ -> true
  | Var i -> 0 <= i
  | ProdL p _ | ProdR _ p -> valid_expr' p
  | Sum p q | Diff p q -> valid_expr' p && valid_expr' q
  end

let predicate valid_eq' (eq:equality')
= match eq with (lhs, rhs) -> valid_expr' lhs && valid_expr' rhs end

let rec predicate valid_ctx' (ctx:context')
= match ctx with
  | Nil -> true
  | Cons eq rest -> valid_eq' eq && valid_ctx' rest
  end

(* Linearize a surface expression into an [expr] with the same value;
   validity of unknown indices is preserved. *)
let rec simp (e:expr') : expr
  ensures { forall y z. interp result y z = interp' e y z }
  ensures { valid_expr' e -> valid_expr result }
  raises  { C.Unknown -> true }
  variant { e }
=
  (* evaluate a constant product to a single coefficient *)
  let rec simp_c (e:cprod) : coeff
    ensures { forall y z. C.interp result z = interp_c e y z }
    variant { e }
    raises  { C.Unknown -> true }
  = match e with
    | C c -> c
    | Times c1 c2 -> C.mul (simp_c c1) (simp_c c2)
    end
  in
  match e with
  | Var n -> Term C.cone n
  | Coeff c -> Cst c
  | Sum e1 e2 -> Add (simp e1) (simp e2)
  | Diff e1 e2 -> Add (simp e1) (opp_expr (simp e2))
  | ProdL e c | ProdR c e -> mul_expr (simp e) (simp_c c)
  end

(* Linearize both sides of a surface equation. *)
let simp_eq (eq:equality') : equality
  ensures { forall y z. interp_eq result y z = interp_eq' eq y z }
  ensures { valid_eq' eq -> valid_eq result }
  raises  { (*NonLinear -> true | *)C.Unknown -> true }
= let (lhs, rhs) = eq in
  (simp lhs, simp rhs)

(* Linearize every hypothesis and the goal. *)
let rec simp_ctx (ctx: context') (g:equality') : (context, equality)
  returns { (rc, rg) ->
            (valid_ctx' ctx -> valid_eq' g -> valid_ctx rc /\ valid_eq rg) /\
            forall y z. interp_ctx rc rg y z = interp_ctx' ctx g y z }
  raises  { (*NonLinear -> true | *) C.Unknown -> true }
  variant { ctx }
= match ctx with
  | Nil -> Nil, simp_eq g
  | Cons hyp rest ->
    let srest, sg = simp_ctx rest g in
    Cons (simp_eq hyp) srest, sg
  end

(* Entry point: linearize hypotheses and goal, then run [linear_decision].
   [true] means [g] follows from [l]; [false] means no proof was found. *)
let decision (l:context') (g:equality')
  requires { valid_ctx' l }
  requires { valid_eq' g }
  ensures { forall y z. result -> interp_ctx' l g y z }
  raises  { (* NonLinear -> true | *) C.Unknown -> true | Absurd -> true }
= let sctx, sgoal = simp_ctx l g in
  linear_decision sctx sgoal
end
module RationalCoeffs

use import int.Int
use import real.RealInfix
use import real.FromInt
use import int.Abs

(*meta coercion function from_int*)

(* A rational coefficient as a (numerator, denominator) pair. A zero
   denominator is representable, but [radd], [rmul] and [rinv] raise
   [Unknown] on it. *)
type t = (int, int)
type rvars = int -> real

exception Unknown

let constant rzero = (0,1)
let constant rone = (1,1)

(* Real value n / d of the rational (n, d); [v] is not used. *)
function rinterp (t:t) (v:rvars) : real
= match t with (n, d) -> from_int n /. from_int d end

(* A non-zero real factor can be cancelled on both sides. *)
let lemma prod_compat_eq (a b c:real)
  requires { c <> 0.0 }
  requires { a *. c = b *. c }
  ensures  { a = b }
= ()

(* Equal cross products imply equal rational values. *)
let lemma cross_d (n1 d1 n2 d2:int)
  requires { d1 <> 0 /\ d2 <> 0 }
  requires { n1 * d2 = n2 * d1 }
  ensures { forall v. rinterp (n1,d1) v = rinterp (n2,d2) v }
= let dd = from_int (d1 * d2) in
  (* compare both values after multiplying by the non-zero d1 * d2 *)
  assert { forall v. rinterp (n1, d1) v = rinterp (n2, d2) v
           by rinterp (n1, d1) v *. dd = rinterp (n2,d2) v *. dd }

(* Converse of [cross_d]: equal rational values imply equal cross
   products. *)
let lemma cross_ind (n1 d1 n2 d2:int)
  requires { d1 <> 0 /\ d2 <> 0 }
  requires { forall v. rinterp (n1,d1) v = rinterp (n2,d2) v }
  ensures  { n1 * d2 = n2 * d1 }
= (* both denominators are non-zero reals *)
  assert { from_int d1 <> 0.0 /\ from_int d2 <> 0.0 };
  assert { from_int n1 /. from_int d1 = from_int n2 /. from_int d2 };
  (* multiply the equal quotients by d1 * d2 *)
  assert { from_int n1 *. from_int d2 = from_int n2 *. from_int d1
           by from_int n1 *. from_int d2
              = (from_int n1 /. from_int d1) *. from_int d1 *. from_int d2
              = (from_int n2 /. from_int d2) *. from_int d1 *. from_int d2
              = from_int n2 *. from_int d1 };
  (* back from reals to integers *)
  assert { from_int (n1*d2) = from_int (n2 * d1) }


(* With non-zero denominators, equality of rational values is exactly
   equality of cross products. *)
lemma cross: forall n1 d1 n2 d2: int. d1 <> 0 -> d2 <> 0 ->
             n1 * d2 = n2 * d1 <->
             forall v. rinterp (n1,d1) v = rinterp (n2,d2) v
use import int.ComputerDivision
use import ref.Ref
use import number.Gcd

(* Euclid's algorithm on positive integers; the result equals the
   logical [gcd] of number.Gcd and is positive. *)
let gcd (x:int) (y:int)
  requires { x > 0 /\ y > 0 }
  ensures { result = gcd x y }
  ensures { result > 0 }
=
  let ghost ox = x in
  let u = ref x in
  let w = ref y in
  label Pre in
  while !w > 0 do
    invariant { !u >= 0 /\ !w >= 0 }
    invariant { gcd !u !w = gcd (!u at Pre) (!w at Pre) }
    invariant { ox > 0 -> !u > 0 }
    variant { !w }
    let rest = mod !u !w in
    let ghost quot = div !u !w in
    assert { rest = !u - quot * !w };
    u := !w;
    w := rest
  done;
  !u
(* Reduce (n, d) by the gcd of |n| and |d|. A zero numerator gives
   [rzero]; a zero denominator leaves the pair unchanged. The real value
   is preserved. *)
let simp (t:t) : t
  ensures { forall v:rvars. rinterp result v = rinterp t v }
= match t with
  | (n,d) ->
    if d = 0 then t
    else if n = 0 then rzero
    else begin
      let g = gcd (abs n) (abs d) in
      let num = div n g in
      let den = div d g in
      assert { n = g * num /\ d = g * den };
      assert { num * d = n * den };
      (num, den)
    end
  end
(* Rational addition: n1/d1 + n2/d2 = (n1*d2 + n2*d1) / (d1*d2), then
   reduced by [simp]. Raises [Unknown] on a zero denominator. *)
let radd (a b:t)
  ensures { forall y. rinterp result y = rinterp a y +. rinterp b y }
  raises  { Unknown -> true }
= match (a,b) with
  | (n1,d1), (n2,d2) ->
    if d1 = 0 || d2 = 0 then raise Unknown;
    let sum = (n1*d2 + n2*d1, d1*d2) in
    let ghost dd = from_int d1 *. from_int d2 in
    assert { forall y.
             rinterp a y +. rinterp b y = rinterp sum y
             by rinterp a y *. dd = from_int n1 *. from_int d2
             so rinterp b y *. dd = from_int n2 *. from_int d1
             so (rinterp a y +. rinterp b y) *. dd
                = from_int (n1*d2 + n2*d1)
                = rinterp sum y *. dd };
    simp sum
  end

(* Product of two fractions, n1/d1 * n2/d2 = (n1*n2) / (d1*d2), reduced
   with [simp].  Raises [Unknown] when either denominator is 0. *)
let rmul (a b:t)
  ensures { forall y. rinterp result y = rinterp a y *. rinterp b y }
  raises  { Unknown -> true }
=
  let n1, d1 = a in
  let n2, d2 = b in
  if d1 = 0 || d2 = 0 then raise Unknown;
  let p = (n1*n2, d1*d2) in
  assert { forall y. rinterp p y = rinterp a y *. rinterp b y
           by rinterp p y = from_int (n1*n2) /. from_int (d1*d2)
              = (from_int n1 *. from_int n2) /. (from_int d1 *. from_int d2)
              = (from_int n1 /. from_int d1) *. (from_int n2 /. from_int d2)
              = rinterp a y *. rinterp b y };
  simp p

(* Negation of a fraction: only the numerator changes sign. *)
let ropp (a:t)
  ensures { forall y. rinterp result y = -. rinterp a y }
=
  let n, d = a in
  (-n, d)

(* Equality test on fractions.  It holds for syntactically identical pairs,
   or for pairs with non-zero denominators whose cross products coincide.
   The contract only states soundness: a true result implies equal
   interpretations. *)
let predicate req (a b:t)
  ensures { result -> forall y. rinterp a y = rinterp b y }
=
  let n1, d1 = a in
  let n2, d2 = b in
  (n1 = n2 && d1 = d2) || (d1 <> 0 && d2 <> 0 && n1 * d2 = n2 * d1)

938 939 940 941
(* Reciprocal of a fraction, obtained by swapping numerator and denominator.
   Raises [Unknown] when either component is 0. *)
let rinv (a:t)
  requires { not req a rzero }
  ensures  { not req result rzero }
  ensures  { forall y. rinterp result y *. rinterp a y = 1.0 }
  raises   { Unknown -> true }
=
  let n, d = a in
  if n = 0 || d = 0 then raise Unknown;
  (d, n)

end
948

949 950 951
(* Linear decision procedure with rational coefficients.  The generic
   [LinearEquationsDecision] is instantiated as follows:
   - carrier: [real];
   - coefficients: the integer fractions [t] of [RationalCoeffs];
   - coefficient operations: [radd], [rmul], [ropp], [req], [rinv]. *)
module LinearDecisionRational

use import RationalCoeffs
use import real.RealInfix
use import real.FromInt

clone export LinearEquationsDecision with
  type C.a = real,
  function C.(+) = (+.),
  function C.( * ) = ( *. ),
  function C.(-_) = (-._),
  function C.(-) = (-.),
  type coeff = t,
  type C.cvars = int -> real,
  function C.interp = rinterp,
  exception C.Unknown = Unknown,
  constant C.azero = Real.zero,
  constant C.aone = Real.one,
  predicate C.ale = (<=.),
  val C.czero = rzero,
  val C.cone = rone,
  lemma C.sub_def,
  lemma C.zero_def,
  lemma C.one_def,
  val C.add = radd,
  val C.mul = rmul,
  val C.opp = ropp,
  val C.eq = req,
  val C.inv = rinv,
  goal C.A.ZeroLessOne,
  goal C.A.CompatOrderAdd,
  goal C.A.CompatOrderMult,
  goal C.A.Unitary,
  goal C.A.NonTrivialRing,
  goal C.A.Mul_distr_l,
  goal C.A.Mul_distr_r,
  goal C.A.Inv_def_l,
  goal C.A.Inv_def_r,
  goal C.A.MulAssoc.Assoc,
  goal C.A.Assoc,
  goal C.A.MulComm.Comm,
  goal C.A.Comm,
  goal C.A.Unit_def_l,
  goal C.A.Unit_def_r

end

(* Linear decision procedure with integer coefficients.
   The procedure is instantiated over [int].  [int_decision] solves integer
   problems by translating them into [LinearDecisionRational]:
   - each integer coefficient x becomes the fraction (x,1);
   - integer valuations are lifted through [from_int]. *)
module LinearDecisionInt

use import int.Int

(* Integer coefficients do not depend on the valuation. *)
function id (t:int) (v:int -> int) : int = t
let predicate eq (a b:int) = a=b

exception Unknown

(* Integers have no general multiplicative inverse: this always raises
   [Unknown]. *)
let inv (t:int) : int
  (*ensures { forall v: int -> int. id result v * id t v = one }*)
  ensures { not (eq result zero) }
  raises { Unknown -> true }
= raise Unknown

clone export LinearEquationsDecision with
  type C.a = int,
  function C.(+) = (+),
  function C.(*) = (*),
  function C.(-_) = (-_),
  function C.(-) = (-),
  type coeff = int,
  type C.cvars = int->int,
  function C.interp = id,
  constant C.azero = zero,
  constant C.aone = one,
  predicate C.ale = (<=),
  val C.czero = zero,
  val C.cone = one,
  lemma C.sub_def,
  lemma C.zero_def,
  lemma C.one_def,
  val C.add = (+),
  val C.mul = (*),
  val C.opp = (-_),
  val C.eq = eq,
  val C.inv = inv,
  goal C.A.ZeroLessOne,
  goal C.A.CompatOrderAdd,
  goal C.A.CompatOrderMult,
  goal C.A.Unitary,
  goal C.A.NonTrivialRing,
  goal C.A.Mul_distr_l,
  goal C.A.Mul_distr_r,
  goal C.A.Inv_def_l,
  goal C.A.Inv_def_r,
  goal C.A.MulAssoc.Assoc,
  goal C.A.Assoc,
  goal C.A.MulComm.Comm,
  goal C.A.Comm,
  goal C.A.Unit_def_l,
  goal C.A.Unit_def_r


use import real.FromInt

use import RationalCoeffs
use LinearDecisionRational as R
use import list.List

(* Embed an integer coefficient as the fraction x/1. *)
let function m (x:int) : (int, int)
  ensures { forall z. rinterp result z = from_int x }
  = (x,1)

(* Lift an integer valuation to a real-valued one. *)
let ghost function m_y (y:int -> int): (int -> real)
  ensures { forall i. result i = from_int (y i) }
= fun i -> from_int (y i)

(* Translate a product of integer coefficients to the rational setting. *)
let rec function m_cprod (e:cprod) : R.cprod
  ensures { forall y z. R.interp_c result (m_y y) (m_y z)
            = from_int (interp_c e y z) }
  variant { e }
= match e with
  | C c -> R.C (m c)
  | Times c1 c2 -> R.Times (m_cprod c1) (m_cprod c2)
  end

(* Translate an integer linear expression to the rational setting,
   preserving both its interpretation and its validity. *)
let rec function m_expr (e:expr') : R.expr'
  ensures { forall y z. R.interp' result (m_y y) (m_y z)
            = from_int (interp' e y z) }
  ensures { valid_expr' e -> R.valid_expr' result }
  variant { e }
= match e with
  | Var i -> R.Var i
  | Coeff c -> R.Coeff (m c)
  | Sum e1 e2 -> R.Sum (m_expr e1) (m_expr e2)
  | Diff e1 e2 -> R.Diff (m_expr e1) (m_expr e2)
  | ProdL e c -> R.ProdL (m_expr e) (m_cprod c)
  | ProdR c e -> R.ProdR (m_cprod c) (m_expr e)
  end

(* Translate an equation side by side. *)
let function m_eq (eq:equality') : R.equality'
  ensures { forall y z. R.interp_eq' result (m_y y) (m_y z)
                        <-> interp_eq' eq y z }
  ensures { valid_eq' eq -> R.valid_eq' result }
= match eq with (e1,e2) -> (m_expr e1, m_expr e2) end

(* Translate a context (list of hypotheses) equation by equation. *)
let rec function m_ctx (ctx:context') : R.context'
  ensures { forall y z g. R.interp_ctx' result (m_eq g) (m_y y) (m_y z) <->
                        interp_ctx' ctx g y z }
  ensures { valid_ctx' ctx -> R.valid_ctx' result }
  variant { ctx }
= match ctx with
  | Nil -> Nil
  | Cons h t -> Cons (m_eq h) (m_ctx t)
  end

(* Decide an integer goal [g] under hypotheses [l] by running the rational
   decision procedure on the translated problem.  A true result means the
   hypotheses imply the goal under every integer valuation. *)
let int_decision (l: context') (g: equality') : bool
  requires { valid_ctx' l }
  requires { valid_eq' g }
  ensures { forall y z. result -> interp_ctx' l g y z }
  raises  { R.Absurd -> true | (* R.NonLinear -> true | *) Unknown -> true }
= R.decision (m_ctx l) (m_eq g)

end

1037 1038

(* Test goal for the rational decision procedure.  Over the reals, the system
   3x + 2y = 21, 7x + 4y = 47 forces x = 5.  Each constant is written as a
   fraction n/1, presumably so that it is read as a rational coefficient;
   TODO confirm against the reification. *)
module Test

use import RationalCoeffs
use import LinearDecisionRational
use import int.Int
use import real.RealInfix
use import real.FromInt

meta coercion function from_int

goal g:
  forall x y: real.
    (from_int 3 /. from_int 1) *. x +. (from_int 2 /. from_int 1) *. y
      = (from_int 21 /. from_int 1) ->
    (from_int 7 /. from_int 1) *. x +. (from_int 4 /. from_int 1) *. y
      = (from_int 47 /. from_int 1) ->
    x = (from_int 5 /. from_int 1)

end

(* Test goal for the integer decision procedure: the system
   3x + 2y = 21, 7x + 4y = 47 over the integers forces x = 5. *)
module TestInt

use import LinearDecisionInt
use import int.Int

goal g:
  forall x y: int.
    3 * x + 2 * y = 21 ->
    7 * x + 4 * y = 47 ->
    x = 5

end

module MP64Coeffs

use mach.int.UInt64 as M
use import real.RealInfix
use import real.FromInt
use import real.PowerReal
use RationalCoeffs as Q
1073
use import int.Int
1074

1075 1076
use import debug.Debug

1077 1078
(* Valuation of the exponent variables. *)
type evars = int -> int

(* Symbolic integer exponents: literals, variables, sums, negations and
   differences. *)
type exp =
  | Lit int
  | Var int
  | Plus exp exp
  | Minus exp
  | Sub exp exp

(* A coefficient pairs a rational factor with a symbolic exponent of the
   radix; it denotes q * rradix^e (see [minterp]). *)
type t = (Q.t, exp)

let constant mzero = (Q.rzero, Lit 0)
let constant mone = (Q.rone, Lit 0)

(* The radix of mach.int.UInt64, as a real number. *)
constant rradix: real = from_int (M.radix)

(* Real value n/d of a fraction (n, d); it needs no valuation. *)
function qinterp (q:Q.t) : real =
  match q with
  | (n, d) -> from_int n /. from_int d
  end

(* [qinterp] agrees with [Q.rinterp] under any valuation. *)
lemma qinterp_def: forall q v. qinterp q = Q.rinterp q v

(* Integer value of an exponent expression under the valuation [y]. *)
function interp_exp (e:exp) (y:evars) : int
= match e with
  | Lit k -> k
  | Var i -> y i
  | Plus a b -> interp_exp a y + interp_exp b y
  | Minus a -> - (interp_exp a y)
  | Sub a b -> interp_exp a y - interp_exp b y
  end
(*
function interp_pow (n:int) : real
= if n >= 0 then from_int (power M.radix n)
  else inv (from_int (power M.radix (-n)))

lemma Pow_sum: forall m n: int. interp_pow (m+n) = interp_pow m *. interp_pow n
*)
(* Real value of a coefficient (q, e) under the exponent valuation [y]:
   qinterp q times rradix raised to the value of e. *)
function minterp (t:t) (y:evars) : real =
  match t with
  | (q, e) -> qinterp q *. pow rradix (from_int (interp_exp e y))
  end

(* Raised by the coefficient operations of this module when they cannot
   build a result, e.g. [madd] on operands whose exponents are not
   recognised as equal. *)
exception Unknown

1116 1117 1118 1119 1120 1121 1122 1123 1124 1125
(* Syntactic negation of an exponent expression.  Double negations and
   differences are simplified instead of wrapped in a new [Minus]. *)
let rec opp_exp (e:exp)
  ensures { forall y. interp_exp result y = - interp_exp e y }
  variant { e }
= match e with
  | Lit k -> Lit (-k)
  | Var _ -> Minus e
  | Plus a b -> Plus (opp_exp a) (opp_exp b)
  | Minus inner -> inner
  | Sub a b -> Sub b a
  end
1126

1127 1128 1129 1130 1131
(* [add_sub_exp e1 e2 s] builds an expression equal to [e1 + e2] when [s]
   holds and to [e1 - e2] otherwise.  It keeps the result small by folding
   literals and cancelling opposite variables.  [e2] is decomposed along
   its [Plus]/[Sub] structure, and each remaining atom is combined into the
   accumulated expression by [add_atom].
   Raises [Unknown] on operand shapes [add_atom] does not handle. *)
let rec add_sub_exp (e1 e2:exp) (s:bool) : exp
  ensures { forall y.
            if s
            then interp_exp result y = interp_exp e1 y + interp_exp e2 y
            else interp_exp result y = interp_exp e1 y - interp_exp e2 y }
  raises  { Unknown -> true }
  variant { e2, e1 }
=
  (* [add_atom e a s] adds (if [s]) or subtracts (otherwise) the atom [a]
     to/from [e].  The returned boolean is true when [a] was merged with an
     existing part of [e] (literal folding or cancellation).  It is false
     when [a] was only attached through a new [Plus]/[Sub] node. *)
  let rec add_atom (e a:exp) (s:bool) : (exp, bool)
    returns { r, b -> forall y.
              if s then interp_exp r y = interp_exp e y + interp_exp a y
                   else interp_exp r y = interp_exp e y - interp_exp a y }
    raises { Unknown -> true }
    variant { e }
  = match (e,a) with
    | Lit n1, Lit n2 -> (if s then Lit (n1+n2) else Lit (n1-n2)), true
    | Lit n, Var i
      -> if n = 0 then (if s then Var i else Minus (Var i)), true
         else (if s then Plus e a else Sub e a), false
    | Var i, Lit n
      -> if n = 0 then Var i, true
      else (if s then Plus e a else Sub e a), false
    | Lit n, Minus e' ->
      if n = 0 then (if s then Minus e' else e'), true
      else (if s then Plus e a else Sub e a), false
    | Minus e', Lit n ->
      if n = 0 then Minus e', true
      else (if s then Plus e a else Sub e a), false
    | Var i, Minus (Var j) | Minus (Var j), Var i ->
      (* when adding, x + (-x) and (-x) + x cancel out *)
      if s && (i = j) then (Lit 0, true)
      else (if s then Plus e a else Sub e a), false
    | Var i, Var j -> if s then Plus e a, false
                      else
                        if i = j then Lit 0, true
                        else Sub e a, false
    | Minus _, Minus _ -> (if s then Plus e a else Sub e a), false
    | Plus e1 e2, _ ->
      (* (e1 + e2) +/- a: try to absorb a into e1 first, then into e2 *)
      let r, b = add_atom e1 a s in
      if b then
        match r with
        | Lit n -> if n = 0 then e2, true else Plus r e2, true
        | _ -> Plus r e2, true
        end
      else let r, b = add_atom e2 a s in Plus e1 r, b
    | Sub e1 e2, _ ->
      (* (e1 - e2) +/- a: absorb a into e1, or into e2 with the opposite
         sign, or otherwise attach it with a new node *)
      let r, b = add_atom e1 a s in
      if b then
        match r with
        | Lit n -> if n = 0 then opp_exp e2, true else Sub r e2, true
        | _ -> Sub r e2, true
        end
      else let r, b = add_atom e2 a (not s) in
           if b then Sub e1 r, true
           else if s then Sub (Plus e1 a) e2, false
                else Sub e1 (Plus e2 a), false
    | _ -> raise Unknown
    end
  in
  match e2 with
   | Plus e1' e2' ->
     (* e1 +/- (e1' + e2') = (e1 +/- e1') +/- e2' *)
     let r = add_sub_exp e1 e1' s in
     match r with
     | Lit n -> if n = 0
                then (if s then e2' else opp_exp e2')
                else add_sub_exp r e2' s
     | _ -> add_sub_exp r e2' s
     end
   | Sub e1' e2' ->
     (* e1 +/- (e1' - e2') = (e1 +/- e1') -/+ e2' *)
     let r = add_sub_exp e1 e1' s in
     match r with
     | Lit n -> if n = 0
                then (if s then opp_exp e2' else e2')
                else add_sub_exp r e2' (not s)
     | _ -> add_sub_exp r e2' (not s)
     end
   | _ -> let r, _ = add_atom e1 e2 s in r
  end

1205 1206 1207 1208 1209 1210
(* Sum of two exponent expressions, simplified by [add_sub_exp]. *)
let add_exp (e1 e2:exp) : exp
  ensures { forall y. interp_exp result y = interp_exp e1 y + interp_exp e2 y }
  raises  { Unknown -> true }
=
  let adding = true in
  add_sub_exp e1 e2 adding


1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223
(* Conservative test that an exponent expression is identically zero.
   The expression is first simplified by adding it to the literal 0 with
   [add_exp]; every leaf of the result must then be the literal 0.
   A false answer carries no information.  Not recursive itself, so no
   [rec]/[variant] is needed. *)
let zero_exp (e:exp) : bool
  ensures { result -> forall y. interp_exp e y = 0 }
  raises  { Unknown -> true }
=
  (* true only when every leaf of [e] is the literal 0 *)
  let rec all_zero (e:exp) : bool
    ensures { result -> forall y. interp_exp e y = 0 }
    variant { e }
  = match e with
    | Lit n -> n = 0
    | Var _ -> false
    | Minus e -> all_zero e
    | Plus e1 e2 -> all_zero e1 && all_zero e2
    | Sub e1 e2 -> all_zero e1 && all_zero e2
    end
  in
  let e' = add_exp (Lit 0) e in (* simplifies exp *)
  all_zero e'

(* Conservative equality test on exponent expressions.  Literals, variables
   and negations are compared structurally.  In every other case the test
   checks that e1 - e2 simplifies to zero.  A false answer carries no
   information. *)
let rec same_exp (e1 e2: exp)
  ensures { result -> forall y. interp_exp e1 y = interp_exp e2 y }
  variant { e1, e2 }
  raises  { Unknown -> true }
= match e1, e2 with
  | Var v1, Var v2 -> v1 = v2
  | Lit n1, Lit n2 -> n1 = n2
  | Minus a, Minus b -> same_exp a b
  | _ ->
    let diff = add_exp e1 (opp_exp e2) in
    zero_exp diff
  end

1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258
(* Sum of two coefficients.  A zero operand is absorbed.  Otherwise the
   rational parts are added when [same_exp] recognises the exponents as
   equal.  Every other case raises [Unknown]; it no longer prints the
   operands as leftover debugging output first. *)
let madd (a b:t)
  ensures { forall y. minterp result y = minterp a y +. minterp b y }
  raises  { Unknown -> true }
  raises  { Q.Unknown -> true }
= match a, b with
  | (q1, e1), (q2, e2) ->
    if Q.req q1 Q.rzero then b
    else if Q.req q2 Q.rzero then a
    else if same_exp e1 e2
    then begin
      let q = Q.radd q1 q2 in
      (* q1*r^e + q2*r^e = (q1+q2)*r^e by distributivity *)
      assert { forall y. minterp (q, e1) y = minterp a y +. minterp b y
               by let p = pow rradix (from_int (interp_exp e1 y)) in
                  minterp (q, e1) y = (qinterp q) *. p
                  = (qinterp q1 +. qinterp q2) *. p
                  = qinterp q1 *. p +. qinterp q2 *. p
                  = minterp a y +. minterp b y };
      (q,e1) end
    else raise Unknown
  end

(* Product of two coefficients.  The rational parts are multiplied and the
   exponents added, using rradix^e1 * rradix^e2 = rradix^(e1+e2).  A zero
   rational product collapses to [mzero]. *)
let mmul (a b:t)
  ensures { forall y. minterp result y = minterp a y *. minterp b y }
  raises  { Q.Unknown -> true }
  raises  { Unknown -> true }
=
  let q1, e1 = a in
  let q2, e2 = b in
  let q = Q.rmul q1 q2 in
  if Q.req q Q.rzero then mzero
  else begin
    let e = add_exp e1 e2 in
    assert { forall y. minterp (q,e) y = minterp a y *. minterp b y
             by let p1 = pow rradix (from_int (interp_exp e1 y)) in
                let p2 = pow rradix (from_int (interp_exp e2 y)) in
                let p  = pow rradix (from_int (interp_exp e y)) in
                interp_exp e y = interp_exp e1 y + interp_exp e2 y
                so p = p1 *. p2
                so minterp (q,e) y = qinterp q *. p
                   = (qinterp q1 *. qinterp q2) *. p
                   = (qinterp q1 *. qinterp q2) *. p1 *. p2
                   = minterp a y *. minterp b y };
    (q,e)
  end

(* Negation of a coefficient: only the rational factor changes sign. *)
let mopp (a:t)
  ensures { forall y. minterp result y = -. minterp a y }
=
  let q, e = a in
  (Q.ropp q, e)

1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302
(* Exception-free syntactic equality test on exponents.  Sums are compared
   up to swapping their operands.  Any other shape (e.g. [Sub]) yields
   false, so a false answer carries no information. *)
let rec predicate pure_same_exp (e1 e2: exp)
  ensures { result -> forall y. interp_exp e1 y = interp_exp e2 y }
  variant { e1, e2 }
= match e1, e2 with
  | Var v1, Var v2 -> v1 = v2
  | Lit n1, Lit n2 -> n1 = n2
  | Minus a, Minus b -> pure_same_exp a b
  | Plus a1 a2, Plus b1 b2 ->
    (* same order, or operands swapped *)
    pure_same_exp a1 b1 && pure_same_exp a2 b2
    || pure_same_exp a1 b2 && pure_same_exp a2 b1
  | _ -> false
  end

1303 1304 1305
let predicate meq (a b:t)
  ensures { result -> forall y. minterp a y = minterp b y }
= match (a,b) with