(* Naive integer matrix multiplication, with its specification and the
   loop invariants that relate the code to that specification. *)
module MatrixMultiplication

  use import int.Int
  use import int.Sum
  use import map.Map
  use import matrix.Matrix

  (* [mul_atom a b i j] is the function that sends k to the single product
     a[i][k] * b[k][j], i.e. the k-th term contributing to entry (i, j). *)
  function mul_atom (a b: matrix int) (i j: int) : int -> int =
    fun k -> a.elts[i][k] * b.elts[k][j]

  (* [matrix_product m a b] holds when every entry (i, j) lying inside the
     bounds of m equals the sum of [mul_atom a b i j] taken from 0 up to
     a.columns.  Only the entries are constrained here; the dimensions of m
     are not tied to those of a and b by this predicate.
     NOTE(review): [sum] comes from int.Sum.  The invariants in mult_naive
     rely on [sum f 0 k] covering the indices 0 <= t < k (so that it is 0
     when k = 0) -- confirm against the library version in use. *)
  predicate matrix_product (m a b: matrix int) =
    forall i j. 0 <= i < m.rows -> 0 <= j < m.columns ->
      m.elts[i][j] = sum (mul_atom a b i j) 0 a.columns

  (* [mult_naive a b] builds a fresh matrix and fills it with the product
     of a and b using three nested loops in i, k, j order.
     Precondition: the column count of a equals the row count of b.
     Postconditions: the result has the rows of a and the columns of b,
     and satisfies [matrix_product result a b]. *)
  let mult_naive (a b: matrix int) : matrix int
    requires { a.columns = b.rows }
    ensures  { result.rows = a.rows /\ result.columns = b.columns }
    ensures  { matrix_product result a b }
  = (* Accumulator, created with every entry equal to 0. *)
    let rs = make (rows a) (columns b) 0 in
    for i = 0 to a.rows - 1 do
      (* Rows from i onwards have not been touched yet: still all zeros. *)
      invariant { forall i0 j0. i <= i0 < rows a /\ 0 <= j0 < columns b ->
                  rs.elts[i0][j0] = 0 }
      (* Rows before i already hold their complete sums. *)
      invariant { forall i0 j0. 0 <= i0 < i /\ 0 <= j0 < columns b ->
                  rs.elts[i0][j0] = sum (mul_atom a b i0 j0) 0 a.columns }
      (* M names the state of rs at the start of the processing of row i. *)
      label M in
      for k = 0 to rows b - 1 do
        (* Frame: every row other than i is exactly as it was at M. *)
        invariant { forall i0 j0. 0 <= i0 < rows a /\ 0 <= j0 < columns b ->
                    i0 <> i -> rs.elts[i0][j0] = (rs at M).elts[i0][j0] }
        (* Each entry of row i holds the partial sum bounded by k. *)
        invariant { forall j0. 0 <= j0 < columns b ->
                    rs.elts[i][j0] = sum (mul_atom a b i j0) 0 k }
        (* I names the state of rs at the start of iteration k. *)
        label I in
        for j = 0 to columns b - 1 do
          (* Frame: entries outside row i, and entries of row i at columns
             j and beyond, are exactly as they were at I. *)
          invariant { forall i0 j0. 0 <= i0 < rows a /\ 0 <= j0 < columns b ->
                      (i0 <> i \/ j0 >= j) ->
                      rs.elts[i0][j0] = (rs at I).elts[i0][j0] }
          (* Columns before j in row i have already absorbed term k. *)
          invariant { forall j0. 0 <= j0 < j ->
                      rs.elts[i][j0] = sum (mul_atom a b i j0) 0 (k+1) }
          (* Add the term a[i][k] * b[k][j] to the running entry (i, j). *)
          set rs i j (get rs i j + get a i k * get b k j)
        done;
      done;
    done;
    rs

end