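(* naive.mlw -- naive matrix multiplication on integer matrices, together
   with its functional specification and the loop invariants used by Why3
   to prove it correct. *)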
module MatrixMultiplication

  use import int.Int
  use import int.Sum
  use import map.Map
  use import matrix.Matrix

  (** [mul_atom a b i j] is the function [k -> a[i][k] * b[k][j]]: the
      [k]-th term of the dot product of row [i] of [a] with column [j]
      of [b]. *)
  function mul_atom (a b: matrix int) (i j: int) : int -> int =
    fun k -> a.elts[i][k] * b.elts[k][j]

  (** [matrix_product m a b] holds when every in-bounds cell [m[i][j]]
      equals [sum (mul_atom a b i j) 0 a.columns], i.e. the dot product of
      row [i] of [a] with column [j] of [b] (summation range as defined by
      [int.Sum]). The dimensions of [m] themselves are not constrained by
      this predicate; callers state them separately. *)
  predicate matrix_product (m a b: matrix int) =
    forall i j. 0 <= i < m.rows -> 0 <= j < m.columns ->
      m.elts[i][j] = sum (mul_atom a b i j) 0 a.columns

  (** [mult_naive a b] returns a fresh [a.rows] x [b.columns] matrix holding
      the product of [a] and [b], computed with three nested loops in
      i-k-j order: for each row [i] and each [k], the term
      [a[i][k] * b[k][j]] is added into every cell of row [i].
      Requires the inner dimensions to agree ([a.columns = b.rows]). *)
  let mult_naive (a b: matrix int) : matrix int
    requires { a.columns = b.rows }
    ensures { result.rows = a.rows /\ result.columns = b.columns }
    ensures { matrix_product result a b }
  = let rs = make (rows a) (columns b) 0 in
    for i = 0 to a.rows - 1 do
      (* rows not yet visited still hold the initial zeros *)
      invariant { forall i0 j0. i <= i0 < rows a /\ 0 <= j0 < columns b ->
        rs.elts[i0][j0] = 0 }
      (* rows already visited hold their final dot products *)
      invariant { forall i0 j0. 0 <= i0 < i /\ 0 <= j0 < columns b ->
        rs.elts[i0][j0] = sum (mul_atom a b i0 j0) 0 a.columns }
      label M in
      (* this loop runs to rows b - 1; the precondition makes that
         a.columns - 1, matching the outer invariant *)
      for k = 0 to rows b - 1 do
        (* only row i is written during this pass *)
        invariant { forall i0 j0. 0 <= i0 < rows a /\ 0 <= j0 < columns b ->
          i0 <> i -> rs.elts[i0][j0] = (rs at M).elts[i0][j0] }
        (* row i holds partial sums over the first k terms *)
        invariant { forall j0. 0 <= j0 < columns b ->
          rs.elts[i][j0] = sum (mul_atom a b i j0) 0 k }
        label I in
        for j = 0 to columns b - 1 do
          (* cells outside row i, and cells of row i not yet reached,
             are unchanged since label I *)
          invariant { forall i0 j0. 0 <= i0 < rows a /\ 0 <= j0 < columns b ->
            (i0 <> i \/ j0 >= j) -> rs.elts[i0][j0] = (rs at I).elts[i0][j0] }
          (* cells of row i already reached now include the k-th term *)
          invariant { forall j0. 0 <= j0 < j ->
             rs.elts[i][j0] = sum (mul_atom a b i j0) 0 (k+1) }
          set rs i j (get rs i j + get a i k * get b k j)
        done;
      done;
    done;
    rs


end