(* matrices_ring_simp.mlw *)
(* Symbolic computation on matrix expressions. *)
module Symb

  (* Imported theories: integers, lists, and the project matrix theories. *)
  use int.Int
  use list.List
  use matrices.Matrix
  use matrices.MatrixArithmetic

  (* Internal layer: monomials, sums of monomials, and the operations on them. *)
  scope LOCAL
    (* ---- Monomials and their interpretation ---- *)

    (* A monomial: a signed product of matrix variables. Variables are integer
       indices; they get a meaning through a valuation f : int -> mat int
       (see l_mdl and m_mdl). *)
    type mono = {
      m_prod : list int; (* variable indices, in product order *)
      m_pos : bool;      (* sign: m_mdl applies opp when this is false *)
    }

    (* Interpretation of an index list as a product of matrices: every index is
       mapped through f and the factors are combined with mul, nested to the
       right. A one-element list denotes f x directly, so no neutral element is
       involved. The empty list gets the placeholder value zero (-1) (-1);
       l_vld never holds on the empty list. *)
    function l_mdl (f:int -> mat int) (l:list int) : mat int =
      match l with
      | Nil -> zero (-1) (-1)
      | Cons x Nil -> f x
      | Cons x q -> mul (f x) (l_mdl f q)
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function l_mdl
    (* l_vld f r c l: the product denoted by l is well-dimensioned, r-by-c.
       The first factor must have r rows, the column count of each factor
       becomes the row count demanded of the remaining product, and the last
       factor must have c columns. The empty list is never valid. *)
    predicate l_vld (f:int -> mat int) (r c:int) (l:list int) =
      match l with
      | Nil -> false
      | Cons x Nil -> rows (f x) = r && cols (f x) = c
      | Cons x q -> rows (f x) = r && l_vld f (cols (f x)) c q
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def predicate l_vld
    (* Dimension lemma for l_mdl: a valid index list denotes a matrix with
       exactly r rows and c columns. Proved by structural recursion on l; the
       empty list is excluded by the precondition (l_vld is false on Nil). *)
    let rec lemma l_mdl_ok (f:int -> mat int) (r c:int) (l:list int) : unit
      requires { l_vld f r c l }
      ensures { let rs = l_mdl f l in rows rs = r /\ cols rs = c }
      variant { l }
    = match l with
      | Nil -> absurd
      | Cons _ Nil -> ()
      (* The tail is valid for the dimensions cols (f x) by c. *)
      | Cons x q -> l_mdl_ok f (cols (f x)) c q
      end
    (* Interpretation of a monomial: the product of its variables (l_mdl),
       passed through opp when the monomial is negative. *)
    function m_mdl (f:int -> mat int) (m:mono) : mat int =
      let m0 = l_mdl f m.m_prod in
      if m.m_pos then m0 else opp m0
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function m_mdl
    (* Dimension lemma for m_mdl: a monomial whose product is valid for r-by-c
       denotes an r-by-c matrix, whatever its sign. Obtained from l_mdl_ok. *)
    let lemma m_mdl_ok (f:int -> mat int) (r c:int) (m:mono) : unit
      requires { l_vld f r c m.m_prod }
      ensures { let rs = m_mdl f m in rows rs = r /\ cols rs = c }
    = l_mdl_ok f r c m.m_prod

    (* Interpretation of a list of monomials as a sum. The empty list denotes
       zero r c; otherwise the head monomial is added on the right of the sum
       of the tail. This is the form used by the correctness lemmas below. *)
    function lm_mdl (f:int -> mat int) (r c:int) (l:list mono) : mat int =
      match l with
      | Nil -> zero r c
      | Cons x q -> add (lm_mdl f r c q) (m_mdl f x)
      end
    (* Variant of lm_mdl in which a one-element list denotes its monomial
       directly, without adding it to zero r c. lm_mdl_same relates the two
       forms; e_mdl is defined with this one. *)
    function lm_mdl_simp (f:int -> mat int) (r c:int) (l:list mono) : mat int =
      match l with
      | Nil -> zero r c
      | Cons x Nil -> m_mdl f x
      | Cons x q -> add (lm_mdl_simp f r c q) (m_mdl f x)
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function lm_mdl_simp
    (* lm_vld f r c l: every monomial of l has a product that is valid for the
       same dimensions r-by-c. The empty list is valid. *)
    predicate lm_vld (f:int -> mat int) (r c:int) (l:list mono) =
      match l with
      | Nil -> true
      | Cons x q -> l_vld f r c x.m_prod && lm_vld f r c q
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def predicate lm_vld
    (* Dimension lemma for lm_mdl: a valid list of monomials denotes an r-by-c
       matrix. The non-negativity of r and c is required here, presumably so
       that zero r c has the expected dimensions (confirm in the matrix
       theory). Proved by structural recursion on l. *)
    let rec lemma lm_mdl_ok (f:int -> mat int) (r c:int) (l:list mono) : unit
      requires { lm_vld f r c l /\ r >= 0 /\ c >= 0 }
      ensures { let rs = lm_mdl f r c l in rows rs = r /\ cols rs = c }
      variant { l }
    = match l with
      | Nil -> ()
      | Cons _ q -> lm_mdl_ok f r c q
      end
    (* On valid lists, lm_mdl_simp and lm_mdl denote the same matrix.
       This is a ghost function rather than a lemma function, so callers
       invoke it explicitly where the equality is needed (the symb_*
       operations do so). Proved by structural recursion on l. *)
    let rec ghost lm_mdl_same (f:int -> mat int) (r c:int) (l:list mono) : unit
      requires { lm_vld f r c l }
      ensures { lm_mdl_simp f r c l = lm_mdl f r c l }
      variant { l }
    = match l with
      | Nil -> ()
      | Cons _ Nil -> ()
      | Cons _ q -> lm_mdl_same f r c q
      end

    (* Lexicographic comparison of two index lists. Returns -1, 0 or 1.
       A proper prefix compares below the longer list (Nil is below any
       non-empty list). *)
    function l_compare (l1 l2:list int) : int = match l1, l2 with
      | Nil, Nil -> 0
      | Nil, _ -> (-1)
      | _, Nil -> 1
      | Cons x q, Cons y r ->
        if x < y then (-1) else if x > y then 1 else l_compare q r
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function l_compare
    (* If l_compare returns 0 the two lists are equal. Ghost function, to be
       called explicitly; proved by simultaneous recursion on both lists.
       Mixed Nil/Cons cases contradict the precondition. *)
    let rec ghost l_compare_zero (l1 l2:list int) : unit
      requires { l_compare l1 l2 = 0 }
      ensures { l1 = l2 }
      variant { l1 }
    = match l1, l2 with
      | Nil, Nil -> ()
      | Cons _ q, Cons _ r -> l_compare_zero q r
      | _ -> absurd
      end

    (* Ordering on monomials used by the merge: m1 is lower than m2 when its
       product is lexicographically smaller, or when the products compare
       equal and m1 positive implies m2 positive. So, for equal products, a
       negative monomial is lower than a positive one, and two monomials of
       the same sign are each lower than the other. *)
    predicate m_lower (m1 m2:mono) =
      let cmp = l_compare m1.m_prod m2.m_prod in
      cmp < 0 || (cmp = 0 && (m1.m_pos -> m2.m_pos))
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def predicate m_lower

    (* Push monomial m on top of l, cancelling when possible: if the head of l
       is negative, m is positive, and both have the same product (l_compare
       returns 0), the two are dropped and the tail of l is returned.
       Only this orientation (negative already on top, positive incoming) is
       cancelled; in every other case m is simply consed onto l. *)
    function m_collapse (l:list mono) (m:mono) : list mono = match l with
    | Nil -> Cons m Nil
    | Cons x q ->
      if not x.m_pos && m.m_pos && l_compare m.m_prod x.m_prod = 0
      then q
      else Cons m l
    end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function m_collapse
    (* Ghost program counterpart of m_collapse. Returns m_collapse l m and
       establishes that the result is still valid for r-by-c and that its sum
       equals the sum of l plus the value of m.
       The match below produces no value; it only supplies proof hints for
       the cancelling case. *)
    let ghost m_collapse_ok (f:int -> mat int) (r c:int)
                            (l:list mono) (m:mono) : list mono
      requires { lm_vld f r c l /\ l_vld f r c m.m_prod }
      requires { r >= 0 /\ c >= 0 }
      ensures { result = m_collapse l m /\ lm_vld f r c result }
      ensures { lm_mdl f r c result = add (lm_mdl f r c l) (m_mdl f m) }
    = let res = m_collapse l m in
      match l with
      | Nil -> ()
      | Cons x q ->
        if not x.m_pos && m.m_pos && l_compare m.m_prod x.m_prod = 0
        then begin
          (* Same comparison result means same product list ... *)
          l_compare_zero m.m_prod x.m_prod;
          (* ... so m and x are opposite and cancel in the sum. *)
          assert { lm_mdl f r c q =
            add (m_mdl f m) (add (lm_mdl f r c q) (m_mdl f x)) }
          end
      end;
      res

    (* Push every monomial of l, from first to last, onto acc with m_collapse.
       The elements of l therefore end up on acc in reverse order (minus any
       cancelled pairs). *)
    function lm_collapse (acc l:list mono) : list mono = match l with
      | Nil -> acc
      | Cons x q -> lm_collapse (m_collapse acc x) q
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function lm_collapse
    (* Ghost program counterpart of lm_collapse. Returns lm_collapse acc l and
       establishes that the result is valid and that its sum is the sum of acc
       plus the sum of l. Each step goes through m_collapse_ok. *)
    let rec ghost lm_collapse_ok (f:int -> mat int) (r c:int)
                                 (acc l:list mono) : list mono
      requires { lm_vld f r c acc /\ lm_vld f r c l }
      requires { r >= 0 /\ c >= 0 }
      ensures { result = lm_collapse acc l /\ lm_vld f r c result }
      ensures { lm_mdl f r c result = add (lm_mdl f r c acc) (lm_mdl f r c l) }
      variant { l }
    = match l with
      | Nil -> acc
      | Cons x q -> lm_collapse_ok f r c (m_collapse_ok f r c acc x) q
      end

    (* Reverse l onto acc: the elements of l are consed one by one onto acc,
       so cat_rev Nil l is l reversed. Polymorphic in the element type. *)
    function cat_rev (acc l:list 'a) : list 'a = match l with
      | Nil -> acc
      | Cons x q -> cat_rev (Cons x acc) q
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function cat_rev
    (* Ghost program counterpart of cat_rev, specialised to lists of
       monomials. Returns cat_rev acc l and establishes that the result is
       valid and that its sum is the sum of acc plus the sum of l. *)
    let rec ghost cat_rev_ok (f:int -> mat int) (r c:int)
      (acc l:list mono) : list mono
      requires { lm_vld f r c acc /\ lm_vld f r c l }
      requires { r >= 0 /\ c >= 0 }
      ensures { result = cat_rev acc l /\ lm_vld f r c result }
      ensures { lm_mdl f r c result = add (lm_mdl f r c acc) (lm_mdl f r c l) }
      variant { l }
    = match l with
      | Nil -> acc
      | Cons x q -> cat_rev_ok f r c (Cons x acc) q
      end

    (* Move the leading monomials of l onto acc (through m_collapse) for as
       long as x is not lower than them. Stops at the first y with m_lower x y,
       or at the end of l. Returns the updated accumulator together with the
       part of l that was not consumed. *)
    function lm_dump (x:mono) (acc l:list mono) : (list mono,list mono) =
      match l with
      | Nil -> (acc, Nil)
      | Cons y q ->
        if m_lower x y then (acc, l) else lm_dump x (m_collapse acc y) q
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function lm_dump
    (* Ghost program counterpart of lm_dump. Returns lm_dump x acc l and
       establishes that both components of the result are valid and that the
       total sum (accumulator plus remaining list) is unchanged. Each moved
       monomial goes through m_collapse_ok. *)
    let rec ghost lm_dump_ok (f:int -> mat int) (r c:int)
      (x:mono) (acc l:list mono) : (list mono,list mono)
      requires { r >= 0 /\ c >= 0 }
      requires { lm_vld f r c acc /\ lm_vld f r c l }
      ensures { result = lm_dump x acc l }
      ensures { let (acc2,l2) = result in
        lm_vld f r c acc2 /\ lm_vld f r c l2 /\
        add (lm_mdl f r c acc2) (lm_mdl f r c l2) =
        add (lm_mdl f r c acc) (lm_mdl f r c l) }
      variant { l }
    = match l with
      | Nil -> (acc, Nil)
      | Cons y q ->
        if m_lower x y then (acc, l)
        else lm_dump_ok f r c x (m_collapse_ok f r c acc y) q
      end

    (* Merge l1 and l2 into the accumulator acc, cancelling opposite monomials
       along the way. For each x of l1: first dump onto acc the leading
       monomials of l2 that x is not lower than, then push x itself with
       m_collapse. When l1 is exhausted the rest of l2 is collapsed onto acc
       and the accumulator, which was built in reverse, is reversed.
       NOTE(review): this reads as the merge step of a merge sort over
       m_lower-ordered lists; no sortedness invariant is stated in this file. *)
    function lm_merge (acc l1 l2:list mono) : list mono = match l1 with
      | Nil -> cat_rev Nil (lm_collapse acc l2)
      | Cons x q ->
        let (acc,l2) = lm_dump x acc l2 in
        lm_merge (m_collapse acc x) q l2
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function lm_merge
    (* Ghost program counterpart of lm_merge. Returns lm_merge acc l1 l2 and
       establishes that the result is valid and that its sum is the sum of
       acc, l1 and l2. Built from the _ok counterparts of the helpers that
       lm_merge uses. *)
    let rec ghost lm_merge_ok (f:int -> mat int) (r c:int)
      (acc l1 l2:list mono) : list mono
      requires { r >= 0 /\ c >= 0 /\ lm_vld f r c acc }
      requires { lm_vld f r c l1 /\ lm_vld f r c l2 }
      ensures { result = lm_merge acc l1 l2 /\ lm_vld f r c result }
      ensures { lm_mdl f r c result =
        add (lm_mdl f r c acc) (add (lm_mdl f r c l1) (lm_mdl f r c l2)) }
      variant { l1 }
    = match l1 with
      | Nil -> cat_rev_ok f r c Nil (lm_collapse_ok f r c acc l2)
      | Cons x q -> let (acc,l2) = lm_dump_ok f r c x acc l2 in
        lm_merge_ok f r c (m_collapse_ok f r c acc x) q l2
      end

    (* List concatenation: the elements of l1 followed by those of l2.
       Polymorphic in the element type. *)
    function cat (l1 l2:list 'a) : list 'a = match l1 with
      | Nil -> l2
      | Cons x q -> Cons x (cat q l2)
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function cat

    (* Ghost program counterpart of cat on index lists. Given l1 valid for
       r-by-k and l2 valid for k-by-c, returns cat l1 l2 and establishes that
       it is valid for r-by-c and denotes mul (l_mdl f l1) (l_mdl f l2).
       Both lists are non-empty under the preconditions, since l_vld is false
       on Nil; hence the absurd branch. *)
    let rec ghost cat_ok (f:int -> mat int) (r k c:int)
                         (l1 l2:list int) : list int
      requires { r >= 0 /\ k >= 0 /\ c >= 0 }
      requires { l_vld f r k l1 /\ l_vld f k c l2 }
      ensures { result = cat l1 l2 /\ l_vld f r c result }
      ensures { l_mdl f result = mul (l_mdl f l1) (l_mdl f l2) }
      variant { l1 }
    = match l1, l2 with
      | Nil, _ | _, Nil -> absurd
      | Cons x Nil, _ -> Cons x l2
      (* The tail of l1 is valid for cols (f x) by k. *)
      | Cons x q, _ -> Cons x (cat_ok f (cols (f x)) k c q l2)
      end

    (* Product of two monomials: the variable lists are concatenated and the
       result is positive exactly when both operands have the same sign. *)
    function m_mul (m1 m2:mono) : mono =
      { m_pos = (m1.m_pos <-> m2.m_pos); m_prod = cat m1.m_prod m2.m_prod }
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function m_mul
    (* Ghost program counterpart of m_mul. Given m1 valid for r-by-k and m2
       valid for k-by-c, returns m_mul m1 m2 and establishes that it is valid
       for r-by-c and denotes mul (m_mdl f m1) (m_mdl f m2). *)
    let ghost m_mul_ok (f:int -> mat int) (r k c:int) (m1 m2:mono) : mono
      requires { r >= 0 /\ k >= 0 /\ c >= 0 }
      requires { l_vld f r k m1.m_prod /\ l_vld f k c m2.m_prod }
      ensures { result = m_mul m1 m2 /\ l_vld f r c result.m_prod }
      ensures { m_mdl f result = mul (m_mdl f m1) (m_mdl f m2) }
    = (* The conditional is the program-level form of m1.m_pos <-> m2.m_pos. *)
      { m_pos = (if m1.m_pos then m2.m_pos else not m2.m_pos);
        m_prod = cat_ok f r k c m1.m_prod m2.m_prod }

    (* Multiply every monomial of l on the left by m (m_mul m x), keeping the
       order of l. *)
    function m_distribute (m:mono) (l:list mono) : list mono = match l with
      | Nil -> Nil
      | Cons x q -> Cons (m_mul m x) (m_distribute m q)
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function m_distribute
    (* Ghost program counterpart of m_distribute. Given m valid for r-by-k and
       l valid for k-by-c, returns m_distribute m l and establishes that it is
       valid for r-by-c and that its sum is mul (m_mdl f m) (lm_mdl f k c l). *)
    let rec ghost m_distribute_ok (f:int -> mat int) (r k c:int)
                                  (m:mono) (l:list mono) : list mono
      requires { r >= 0 /\ k >= 0 /\ c >= 0 }
      requires { l_vld f r k m.m_prod /\ lm_vld f k c l }
      ensures { result = m_distribute m l /\ lm_vld f r c result }
      ensures { lm_mdl f r c result = mul (m_mdl f m) (lm_mdl f k c l) }
      variant { l }
    = match l with
      | Nil -> Nil
      | Cons x q -> Cons (m_mul_ok f r k c m x) (m_distribute_ok f r k c m q)
      end

    (* Product of two sums of monomials: for each monomial x of l1, the list
       m_distribute x l2 is merged (lm_merge, empty accumulator) with the
       product of the rest of l1 by l2. *)
    function lm_distribute (l1 l2:list mono) : list mono = match l1 with
      | Nil -> Nil
      | Cons x q -> lm_merge Nil (m_distribute x l2) (lm_distribute q l2)
      end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function lm_distribute
    (* Ghost program counterpart of lm_distribute. Given l1 valid for r-by-k
       and l2 valid for k-by-c, returns lm_distribute l1 l2 and establishes
       that it is valid for r-by-c and that its sum is the mul of the two
       sums. *)
    let rec ghost lm_distribute_ok (f:int -> mat int) (r k c:int)
                                   (l1 l2:list mono) : list mono
      requires { r >= 0 /\ k >= 0 /\ c >= 0 }
      requires { lm_vld f r k l1 /\ lm_vld f k c l2 }
      ensures { result = lm_distribute l1 l2 /\ lm_vld f r c result }
      ensures { lm_mdl f r c result = mul (lm_mdl f r k l1) (lm_mdl f k c l2) }
      variant { l1 }
    = match l1 with
      | Nil -> Nil
      | Cons x q ->
        lm_merge_ok f r c Nil (m_distribute_ok f r k c x l2)
                              (lm_distribute_ok f r k c q l2)
      end

    (* Negation of a sum of monomials: the sign of each monomial is flipped.
       The flipped head is re-inserted into the negated tail with lm_merge
       rather than consed directly.
       NOTE(review): presumably because flipping signs changes the relative
       m_lower order of monomials with equal products; confirm. *)
    function lm_opp (l:list mono) : list mono = match l with
    | Nil -> Nil
    | Cons x q ->
        lm_merge Nil (Cons { x with m_pos = not x.m_pos } Nil)
                     (lm_opp q)
    end
    (* Make the definition available as a rewrite rule. *)
    meta rewrite_def function lm_opp
    (* Ghost program counterpart of lm_opp. Returns lm_opp l and establishes
       that it is valid for r-by-c and that its sum is opp of the sum of l. *)
    let rec ghost lm_opp_ok (f:int -> mat int) (r c:int)
                            (l:list mono) : list mono
      requires { r >= 0 /\ c >= 0 /\ lm_vld f r c l }
      ensures { result = lm_opp l /\ lm_vld f r c result }
      ensures { lm_mdl f r c result = opp (lm_mdl f r c l) }
      variant { l }
    = match l with
      | Nil -> Nil
      (* m2 is the head monomial with its sign flipped. *)
      | Cons x q -> let m2 = { x with m_pos = not x.m_pos } in
        lm_merge_ok f r c Nil (Cons m2 Nil) (lm_opp_ok f r c q);
      end

    (* Initial valuation used by symb_env: zero partially applied to 0, so
       every index is mapped to zero 0 n for its own number n.
       NOTE(review): the argument order of zero (rows first or columns first)
       is defined in the matrix theory, not here. *)
    constant empty : int -> mat int = zero 0
  end

  (* Symbolic environment: the valuation of matrix variables together with a
     counter. symb_reg uses the current counter value as the index of the
     newly registered matrix and then increments it. *)
  type env = {
    mutable ev_f : int -> mat int; (* valuation: variable index to matrix *)
    mutable ev_c : int;            (* index given to the next registration *)
  }

  (* Symbolic matrix expression: a list of monomials (read as their sum by
     e_mdl) together with the dimensions the expression is meant to have.
     Well-formedness with respect to an environment is stated by e_vld. *)
  type expr = {
    e_body : list LOCAL.mono; (* the monomials of the sum *)
    e_rows : int;             (* row count of the denoted matrix *)
    e_cols : int;             (* column count of the denoted matrix *)
  }

  (* x --> y is the equality x = y carrying the [@rewrite] attribute. It is
     used in the postconditions of the symb_* operations.
     NOTE(review): presumably this lets computation-based transformations use
     those postconditions as left-to-right rewrite rules; confirm against the
     Why3 documentation. *)
  predicate (-->) (x y:'a) = [@rewrite] x = y
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def predicate (-->)

  (* e_vld env e: the declared dimensions of e are non-negative and every
     monomial of its body is valid for those dimensions under the valuation
     of env. *)
  predicate e_vld (env:env) (e:expr) =
    e.e_rows >= 0 /\ e.e_cols >= 0 /\
    LOCAL.lm_vld env.ev_f e.e_rows e.e_cols e.e_body
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def predicate e_vld
  (* The matrix denoted by e under the valuation of env. Defined with
     LOCAL.lm_mdl_simp, the form in which a single monomial is not added to a
     zero matrix. *)
  function e_mdl (env:env) (e:expr) : mat int =
    LOCAL.lm_mdl_simp env.ev_f e.e_rows e.e_cols e.e_body
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def function e_mdl

  (* extends f c v is the valuation that maps index c to v and agrees with f
     on every other index. *)
  function extends (f:int -> mat int) (c:int) (v:mat int) : int -> mat int =
    fun n -> if n <> c then f n else v
  (* Pointwise unfolding of extends, registered as a rewrite lemma. *)
  lemma extends_rw : forall f c v n.
    extends f c v n = if n <> c then f n else v
  meta rewrite lemma extends_rw

  (* Create a fresh symbolic environment: the valuation is LOCAL.empty and
     the counter starts at 0 (exposed by the postcondition). *)
  let ghost symb_env () : env
    ensures { result.ev_c --> 0 }
  = { ev_f = LOCAL.empty; ev_c = 0 }

  (* The expression made of the single positive monomial whose product is
     the one variable n, with the dimensions of m. *)
  function symb_mat (m:mat int) (n:int) : expr =
    { e_rows = m.rows; e_cols = m.cols;
      e_body = Cons { LOCAL.m_pos = true; LOCAL.m_prod = Cons n Nil } Nil }
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def function symb_mat
  (* Register matrix m as a new variable of env. The current counter value
     becomes the index of m: the valuation is extended at that index, the
     counter is incremented, and the returned expression is symb_mat m at
     that index. The postconditions describe the new counter, the new
     valuation and the result in terms of the old state, and state that the
     result is valid in the updated environment. *)
  let ghost symb_reg (ghost env:env) (m:mat int) : expr
    writes { env }
    ensures { env.ev_c --> old env.ev_c + 1 }
    ensures { env.ev_f --> extends (old env.ev_f) (old env.ev_c) m }
    ensures { result --> symb_mat m (old env.ev_c) }
    ensures { e_vld env result }
  = let id = env.ev_c in
    env.ev_f <- extends env.ev_f id m; env.ev_c <- id + 1;
    symb_mat m id

  (* Logic definition of symbolic negation: same dimensions, body negated
     with LOCAL.lm_opp. *)
  function symb_opp (e:expr) : expr =
    { e_rows = e.e_rows; e_cols = e.e_cols; e_body = LOCAL.lm_opp e.e_body }
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def function symb_opp

  (* Ghost program for symbolic negation. For a valid e, returns the logic
     symb_opp e, which is valid and denotes opp (e_mdl env e).
     LOCAL.lm_mdl_same is called on the input and on the result to move
     between lm_mdl_simp (used by e_mdl) and lm_mdl (used by the LOCAL
     correctness results). *)
  let ghost symb_opp (env:env) (e:expr) : expr
    requires { e_vld env e }
    ensures { e_vld env result }
    ensures { result --> symb_opp e }
    ensures { e_mdl env result = opp (e_mdl env e) }
  = let (r,c) = (e.e_rows,e.e_cols) in
    LOCAL.lm_mdl_same env.ev_f r c e.e_body;
    let res = { e_rows = r; e_cols = c;
      e_body = LOCAL.lm_opp_ok env.ev_f r c e.e_body } in
    LOCAL.lm_mdl_same env.ev_f r c res.e_body;
    res

  (* Logic definition of symbolic addition: the dimensions are taken from
     e1 and the bodies are merged with LOCAL.lm_merge, starting from an
     empty accumulator. *)
  function symb_add (e1 e2:expr) : expr =
    { e_rows = e1.e_rows; e_cols = e1.e_cols;
      e_body = LOCAL.lm_merge Nil e1.e_body e2.e_body }
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def function symb_add

  (* Ghost program for symbolic addition. For valid operands with equal
     dimensions, returns the logic symb_add e1 e2, which is valid and
     denotes add (e_mdl env e1) (e_mdl env e2).
     LOCAL.lm_mdl_same is called on both inputs and on the result to move
     between lm_mdl_simp (used by e_mdl) and lm_mdl (used by the LOCAL
     correctness results). *)
  let ghost symb_add (env:env) (e1 e2:expr) : expr
    requires { e_vld env e1 /\ e_vld env e2 }
    requires { e1.e_rows = e2.e_rows /\ e1.e_cols = e2.e_cols }
    ensures { e_vld env result }
    ensures { result --> symb_add e1 e2 }
    ensures { e_mdl env result = add (e_mdl env e1) (e_mdl env e2) }
  = let (r,c) = (e1.e_rows, e1.e_cols) in
    LOCAL.lm_mdl_same env.ev_f r c e1.e_body;
    LOCAL.lm_mdl_same env.ev_f r c e2.e_body;
    let res = { e_rows = r; e_cols = c;
        e_body = LOCAL.lm_merge_ok env.ev_f r c Nil e1.e_body e2.e_body } in
    LOCAL.lm_mdl_same env.ev_f r c res.e_body;
    res

  (* Logic definition of symbolic subtraction: e1 plus the negation of e2. *)
  function symb_sub (e1 e2:expr) : expr =
    symb_add e1 (symb_opp e2)
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def function symb_sub

  (* Ghost program for symbolic subtraction, composed from the symb_opp and
     symb_add programs. For valid operands with equal dimensions, returns
     the logic symb_sub e1 e2, which is valid and denotes
     sub (e_mdl env e1) (e_mdl env e2).
     NOTE(review): the last postcondition relies on how sub relates to add
     and opp in the matrix theory, which is not visible in this file. *)
  let ghost symb_sub (env:env) (e1 e2:expr) : expr
    requires { e_vld env e1 /\ e_vld env e2 }
    requires { e1.e_rows = e2.e_rows /\ e1.e_cols = e2.e_cols }
    ensures { e_vld env result }
    ensures { result --> symb_sub e1 e2 }
    ensures { e_mdl env result = sub (e_mdl env e1) (e_mdl env e2) }
  = symb_add env e1 (symb_opp env e2)

  (* Logic definition of symbolic multiplication: rows of e1, columns of e2,
     and the bodies multiplied out with LOCAL.lm_distribute. *)
  function symb_mul (e1 e2:expr) : expr =
    { e_rows = e1.e_rows; e_cols = e2.e_cols;
      e_body = LOCAL.lm_distribute e1.e_body e2.e_body }
  (* Make the definition available as a rewrite rule. *)
  meta rewrite_def function symb_mul

  (* Ghost program for symbolic multiplication. For valid operands whose
     inner dimensions agree (columns of e1 = rows of e2), returns the logic
     symb_mul e1 e2, which is valid and denotes
     mul (e_mdl env e1) (e_mdl env e2).
     LOCAL.lm_mdl_same is called on both inputs and on the result to move
     between lm_mdl_simp (used by e_mdl) and lm_mdl (used by the LOCAL
     correctness results). *)
  let ghost symb_mul (env:env) (e1 e2:expr) : expr
    requires { e_vld env e1 /\ e_vld env e2 }
    requires { e1.e_cols = e2.e_rows }
    ensures { e_vld env result }
    ensures { result --> symb_mul e1 e2 }
    ensures { e_mdl env result = mul (e_mdl env e1) (e_mdl env e2) }
  = let (r,k,c) = (e1.e_rows,e1.e_cols,e2.e_cols) in
    LOCAL.lm_mdl_same env.ev_f r k e1.e_body;
    LOCAL.lm_mdl_same env.ev_f k c e2.e_body;
    let res = { e_rows = r; e_cols = c;
        e_body = LOCAL.lm_distribute_ok env.ev_f r k c e1.e_body e2.e_body } in
    LOCAL.lm_mdl_same env.ev_f r c res.e_body;
    res

  (* Step bound for computation-based transformations.
     NOTE(review): presumably raised so that the normalisation needed by the
     harness below can run to completion; confirm. *)
  meta compute_max_steps 0x1000000

  (* Test harness for the symbolic layer. Six matrices are registered as
     variables, two expressions are built from them, and the final assertion
     states that both denote the same matrix:
       m11  = (x1 + x4 - x5) + x7   with
              x1 = (a11 + a22) (b11 + b22),  x4 = a22 (b21 - b11),
              x5 = (a11 + a12) b22,          x7 = (a12 - a22) (b21 + b22)
       gm11 = a11 b11 + a12 b21
     This has the shape of the top-left block identity of the Strassen
     product.
     NOTE(review): === comes from the matrix theories; presumably it states
     that the matrices have the same dimensions. Together with the second
     precondition that would make all six matrices square of one size. *)
  let harness (a11 a12 a22 b11 b21 b22:mat int) : unit
    requires { a11 === a12 === a22 === b11 === b21 === b22 }
    requires { a11.rows = a11.cols }
  = let env = symb_env () in
    (* Registration order fixes the variable indices 0 to 5. *)
    let e_a11 = symb_reg env a11 in
    let e_a12 = symb_reg env a12 in
    let e_a22 = symb_reg env a22 in
    let e_b11 = symb_reg env b11 in
    let e_b21 = symb_reg env b21 in
    let e_b22 = symb_reg env b22 in
    let x1 = symb_mul env (symb_add env e_a11 e_a22)
                          (symb_add env e_b11 e_b22) in
    let x4 = symb_mul env e_a22 (symb_sub env e_b21 e_b11) in
    let x5 = symb_mul env (symb_add env e_a11 e_a12) e_b22 in
    let x7 = symb_mul env (symb_sub env e_a12 e_a22)
                          (symb_add env e_b21 e_b22) in
    let m11 = symb_add env (symb_sub env (symb_add env x1 x4) x5) x7 in
    let gm11 = symb_add env (symb_mul env e_a11 e_b11)
                            (symb_mul env e_a12 e_b21) in
    assert { e_mdl env m11 = e_mdl env gm11 }

end