Commit 8a4341d1 authored by Martin Clochard's avatar Martin Clochard
Browse files

new example: matrix multiplication

parent 32aac69d
This is a complete solution to VerifyThis 2016 challenge 1.
A self-contained solution to task 1 can be found in file naive.mlw
Other files solve task 2 and 3:
- sum_extended.mlw proves a few extra lemmas on top of what
the standard library already provides
- matrices.mlw contains the theories of matrices, matrix arithmetic,
and block product. Except for providing the requested program,
it solves task 2.
- matrices_ring_simp.mlw is a support file for proof by reflection
of matrix algebraic equations
- strassen.mlw is the implementation of Strassen's Algorithm for task 3.
It also contains the program associated to task 2 (assoc_proof).
To replay proofs:
- Install the Why3 development version from the git source repository
(at the time this file was written, it would not replay with the
release version due to a known incompleteness bug in compute)
- Install SMT solvers Alt-Ergo 1.01, CVC4 1.4 and Z3 4.4.1
- Run why3 replay -L . FILE to replay the session associated to FILE.
Challenge text:
Challenge 1: Matrix Multiplication
Consider the following pseudocode algorithm, which is a naive implementation of matrix multiplication. For simplicity we assume that the matrices are square.
int[][] matrixMultiply(int[][] A, int[][] B) {
int n = A.length;
// initialise C
int[][] C = new int[n][n];
for (int i = 0; i < n; i++) {
for (int k = 0; k < n; k++) {
for (int j = 0; j < n; j++) {
C[i][j] += A[i][k] * B[k][j];
}
}
}
return C;
}
Tasks.
1. Provide a specification to describe the behaviour of this algorithm, and prove
that it correctly implements its specification.
2. Show that matrix multiplication is associative, i.e., the order in which
matrices are multiplied can be disregarded: A(BC) = (AB)C. To show this,
you should write a program that performs the two different computations,
and then prove that the result of the two computations is always the same.
3. [Optional, if time permits] In the literature, there exist many proposals
for more efficient matrix multiplication algorithms. Strassen’s algorithm
was one of the first. The key idea of the algorithm is to use a recursive
algorithm that reduces the number of multiplications on submatrices
(from 8 to 7), see https://en.wikipedia.org/wiki/Strassen_algorithm for an
explanation. A relatively clean Java implementation (and Python and C++)
can be found here: https://martin-thoma.com/strassen-algorithm-in-python-java-cpp/.
Prove that the naive algorithm above has the same behaviour as Strassen’s
algorithm. Proving it for a restricted case, like a 2x2 matrix should be
straightforward, the challenge is to prove it for arbitrary matrices with size 2^n.
(* Generic theory of matrices over an arbitrary cell type.
   The matrix type is abstract; it is characterised only by its dimension
   accessors (rows, cols), a cell reader (get), a functional cell update
   (set) and the axioms below. *)
theory MatrixGen
use import int.Int
(* Abstract, polymorphic matrix type. *)
type mat 'a
(* Dimension accessors. *)
function rows (mat 'a) : int
function cols (mat 'a) : int
(* Both dimensions of any matrix are non-negative. *)
axiom rows_and_cols_nonnegative:
forall m: mat 'a. 0 <= rows m /\ 0 <= cols m
(* get m i j reads cell (i, j); set m i j v is the matrix obtained from m
   by writing v at cell (i, j). Neither is given a definition: only the
   set_def axioms constrain them. *)
function get (mat 'a) int int : 'a
function set (mat 'a) int int 'a : mat 'a
(* An in-bounds update preserves the number of rows. *)
axiom set_def1:
forall m: mat 'a, i j: int, v: 'a. 0 <= i < rows m -> 0 <= j < cols m ->
rows (set m i j v) = rows m
(* An in-bounds update preserves the number of columns. *)
axiom set_def2:
forall m: mat 'a, i j: int, v: 'a. 0 <= i < rows m -> 0 <= j < cols m ->
cols (set m i j v) = cols m
(* Reading back the cell that was just written yields the written value. *)
axiom set_def3:
forall m: mat 'a, i j: int, v: 'a. 0 <= i < rows m -> 0 <= j < cols m ->
get (set m i j v) i j = v
(* An in-bounds update leaves every other cell unchanged. Only the updated
   position (i, j) is required to be in bounds; the read position is not
   constrained. *)
axiom set_def4:
forall m: mat 'a, i j: int, v: 'a. 0 <= i < rows m -> 0 <= j < cols m ->
forall i' j': int. (i <> i' \/ j <> j') ->
get (set m i j v) i' j' = get m i' j'
(* Extensional equality: same dimensions and same content on every
   in-bounds cell. *)
predicate (==) (m1 m2: mat 'a) =
rows m1 = rows m2 && cols m1 = cols m2 &&
forall i j: int. 0 <= i < rows m1 -> 0 <= j < cols m1 ->
get m1 i j = get m2 i j
(* Extensionally equal matrices are equal. This is what lets the lemmas in
   later modules prove an equality by establishing == instead. *)
axiom extensionality:
forall m1 m2: mat 'a. m1 == m2 -> m1 = m2
(* Shape compatibility: a === b holds when a and b have the same dimensions;
   cell contents are not compared. *)
predicate (===) (a b: mat 'a) =
rows a = rows b /\ cols a = cols b
end
(* Matrices whose dimensions are chosen when the matrix is created.
   Extends MatrixGen with a constructor create r c f that builds a matrix
   from a cell-generating function f.
   NOTE(review): the cell type stays polymorphic, so Float presumably refers
   to dimensions that are not fixed rather than to floating-point numbers;
   confirm with the author. *)
theory FloatMatrix
clone export MatrixGen
use HighOrd
use import int.Int
(* create r c f: matrix with r rows and c columns whose cell (i, j) is f i j.
   Declared without a definition; specified by the three axioms below. *)
function create (r c: int) (f: int -> int -> 'a) : mat 'a
(* The row count is the requested one, provided it is non-negative. *)
axiom create_rows:
forall r c: int, f: int -> int -> 'a.
0 <= r -> rows (create r c f) = r
(* The column count is the requested one, provided it is non-negative. *)
axiom create_cols:
forall r c: int, f: int -> int -> 'a.
0 <= c -> cols (create r c f) = c
(* Every in-bounds cell holds the value given by the generating function. *)
axiom create_get:
forall r c: int, f: int -> int -> 'a, i j: int.
0 <= i < r -> 0 <= j < c -> get (create r c f) i j = f i j
end
(* Matrices that all share the same dimensions, given by the theory-level
   constants r (rows) and c (columns). Intended to be cloned with concrete
   values for r and c. *)
theory FixedMatrix
use import int.Int
(* Common number of rows and columns of every matrix of this theory. *)
constant r: int
constant c: int
type mat 'a
(* rows and cols ignore their argument and return the fixed constants. *)
function rows (mat 'a) : int = r
function cols (mat 'a) : int = c
axiom r_and_c_nonnegative:
0 <= r /\ 0 <= c
(* Instantiate the generic theory with the type and accessors above.
   The axiom rows_and_cols_nonnegative of MatrixGen is turned into a goal,
   i.e. it has to be proved here instead of being assumed. *)
clone export MatrixGen with
type mat 'a = mat 'a,
function rows = rows,
function cols = cols,
goal rows_and_cols_nonnegative
use HighOrd
(* create f: matrix whose cell (i, j) is f i j. No dimension arguments are
   taken since the dimensions are fixed to r and c. *)
function create (f: int -> int -> 'a) : mat 'a
(* Every in-bounds cell holds the value given by the generating function. *)
axiom create_get:
forall f: int -> int -> 'a, i j: int.
0 <= i < r -> 0 <= j < c -> get (create f) i j = f i j
end
(* Square matrices of a fixed dimension d: a clone of FixedMatrix in which
   both the row count and the column count are instantiated with d. *)
theory SquareFixedMatrix
use import int.Int
(* Common side length of every matrix of this theory. *)
constant d: int
axiom dimension_nonnegative:
0 <= d
(* The axiom r_and_c_nonnegative of FixedMatrix is turned into a goal,
   i.e. it has to be proved here instead of being assumed. *)
clone export FixedMatrix with
function r = d,
function c = d,
goal r_and_c_nonnegative
end
(* theory Square_Matrix = Matrix with axiom rows_and_cols_nonnegative:
forall m: mat 'a. 0 <= rows m /\ 0 <= cols m /\ rows m = cols m *)
(* Arithmetic on integer matrices (zero, addition, opposite, subtraction,
   product) built on FloatMatrix, together with the usual algebraic laws:
   neutrality of zero, commutativity and associativity of addition,
   associativity of the product and distributivity.
   Most equalities are stated with a by clause that reduces them to the
   extensional equality == of MatrixGen.
   NOTE(review): smulf, sumf, addf, fubini and sum_ext are not defined in
   this file; presumably they come from sum_extended.Sum_extended. Their
   exact statements should be checked there. *)
module MatrixArithmetic
use import int.Int
use import int.Sum
use import sum_extended.Sum_extended
use import FloatMatrix
(* Zero matrix: every cell is produced by the constant-0 generator zerof. *)
constant zerof : int -> int -> int = \_ _. 0
function zero (r c: int) : mat int = create r c zerof
(* Matrix addition: cell-wise sum. The result takes the dimensions of the
   first operand; the definition itself does not require b to have the same
   shape, the lemmas below add a === b where it matters. *)
function add2f (a b: mat int) : int -> int -> int =
\x y. get a x y + get b x y
function add (a b: mat int) : mat int =
create (rows a) (cols a) (add2f a b)
(* Matrix additive inverse: cell-wise negation, same dimensions as a. *)
function opp2f (a: mat int) : int -> int -> int =
\x y. - get a x y
function opp (a: mat int) : mat int =
create (rows a) (cols a) (opp2f a)
(* Subtraction, defined as addition of the opposite. *)
function sub (a b: mat int) : mat int =
add a (opp b)
(* Matrix multiplication.
   mul_atom a b i j is the k-indexed term get a i k * get b k j.
   mul_cell a b gives cell (i, j) as the sum of these terms taken with
   bounds 0 and cols a.
   mul a b has rows a rows and cols b columns. The definition does not
   require cols a = rows b; the lemmas below state it as a hypothesis. *)
function mul_atom (a b: mat int) (i j:int) : int -> int =
\k. get a i k * get b k j
function mul_cell (a b: mat int): int -> int -> int =
\i j. sum 0 (cols a) (mul_atom a b i j)
function mul (a b: mat int) : mat int =
create (rows a) (cols b) (mul_cell a b)
(* Adding the zero matrix of matching shape on the right is the identity. *)
lemma zero_neutral:
forall a. add a (zero a.rows a.cols) = a
by add a (zero a.rows a.cols) == a
(* Addition commutes for matrices of the same shape. *)
lemma add_commutative:
forall a b. a === b ->
add a b = add b a
by add a b == add b a
(* Addition is associative for three matrices of the same shape. *)
lemma add_associative:
forall a b c. a === b === c ->
add a (add b c) = add (add a b) c
by add a (add b c) == add (add a b) c
(* A matrix plus its opposite is the zero matrix of the same shape. *)
lemma add_opposite:
forall a. add a (opp a) = zero a.rows a.cols
by add a (opp a) == zero a.rows a.cols
(* Taking the opposite twice gives back the original matrix. *)
lemma opp_involutive:
forall m. opp (opp m) = m
(* The opposite of a sum is the sum of the opposites (same-shape operands). *)
lemma opposite_add: forall m1 m2.
m1 === m2 -> opp (add m1 m2) = add (opp m1) (opp m2)
(* Two-index families of terms used by mul_assoc_get, one for each
   parenthesisation of the triple product: ft1 combines get c k j with the
   terms of a * b, ft2 combines get a i k with the terms of b * c.
   NOTE(review): the precise meaning depends on smulf, which is defined
   outside this file. *)
function ft1 (a b c: mat int) (i j: int) : int -> int -> int =
\k. smulf (get c k j) (mul_atom a b i k)
function ft2 (a b c: mat int) (i j: int) : int -> int -> int =
\k. smulf (get a i k) (mul_atom b c k j)
(* Cell-wise associativity of the product: for compatible dimensions and an
   in-bounds position (i, j), both parenthesisations of a * b * c have the
   same cell. The body calls fubini on the two families ft1 and ft2, then
   uses sum_ext and an assertion to rewrite each side of the equality as a
   sum of sums over one of those families. *)
let lemma mul_assoc_get (a b c: mat int) (i j: int)
requires { cols a = rows b }
requires { cols b = rows c }
requires { 0 <= i < (rows a) /\ 0 <= j < (cols c) }
ensures { get (mul (mul a b) c) i j = get (mul a (mul b c)) i j }
= let ft1 = ft1 a b c i j in
let ft2 = ft2 a b c i j in
fubini ft1 ft2 0 (cols b) 0 (cols a);
sum_ext 0 (cols b) (mul_atom (mul a b) c i j) (sumf 0 (cols a) ft1);
assert { get (mul (mul a b) c) i j = sum 0 (cols b) (sumf 0 (cols a) ft1) };
sum_ext 0 (cols a) (mul_atom a (mul b c) i j) (sumf 0 (cols b) ft2);
assert { get (mul a (mul b c)) i j = sum 0 (cols a) (sumf 0 (cols b) ft2) }
(* Associativity of the product for compatible dimensions, stated on whole
   matrices and reduced to extensional equality. *)
lemma mul_assoc:
forall a b c. cols a = rows b -> cols b = rows c ->
let ab = mul a b in
let bc = mul b c in
let a_bc = mul a bc in
let ab_c = mul ab c in
a_bc = ab_c
by a_bc == ab_c
(* Cell-wise right distributivity: cell (i, j) of (a + b) * c equals cell
   (i, j) of a * c + b * c, for a and b of the same shape. The assertion
   rewrites the right-hand side as one sum of addf mac mbc, and sum_ext then
   relates that sum to the terms of (a + b) * c. *)
let lemma mul_distr_right_get (a b c: mat int) (i j: int)
requires { 0 <= i < rows a /\ 0 <= j < cols c }
requires { cols b = rows c}
requires { a === b }
ensures { get (mul (add a b) c) i j = get (add (mul a c) (mul b c)) i j }
= let mac = mul_atom a c i j in
let mbc = mul_atom b c i j in
assert { get (add (mul a c) (mul b c)) i j =
get (mul a c) i j + get (mul b c) i j =
sum 0 (cols b) (addf mac mbc) };
sum_ext 0 (cols b) (addf mac mbc) (mul_atom (add a b) c i j)
(* Right distributivity of the product over addition, on whole matrices. *)
lemma mul_distr_right:
forall a b c. a === b -> cols b = rows c ->
mul (add a b) c = add (mul a c) (mul b c)
by mul (add a b) c == add (mul a c) (mul b c)
(* Cell-wise left distributivity: cell (i, j) of a * (b + c) equals cell
   (i, j) of a * b + a * c, for b and c of the same shape. Same proof
   structure as mul_distr_right_get. *)
let lemma mul_distr_left_get (a b c: mat int) (i j : int)
requires { 0 <= i < rows a /\ 0 <= j < cols c }
requires { cols a = rows b }
requires { b === c }
ensures { get (mul a (add b c)) i j = get (add (mul a b) (mul a c)) i j }
= let mab = mul_atom a b i j in
let mac = mul_atom a c i j in
assert { get (add (mul a b) (mul a c)) i j =
get (mul a b) i j + get (mul a c) i j =
sum 0 (cols a) (addf mab mac) };
sum_ext 0 (cols a) (addf mab mac) (mul_atom a (add b c) i j)
(* Left distributivity of the product over addition, on whole matrices. *)
lemma mul_distr_left:
forall a b c. b === c -> cols a = rows b ->
mul a (add b c) = add (mul a b) (mul a c)
by mul a (add b c) == add (mul a b) (mul a c)
(* Multiplying by a zero matrix on the right yields a zero matrix.
   The column count c of the zero operand must be non-negative. *)
lemma mul_zero_right:
forall a c. 0 <= c -> mul a (zero a.cols c) = zero a.rows c
(* Multiplying by a zero matrix on the left yields a zero matrix.
   The row count r of the zero operand must be non-negative. *)
lemma mul_zero_left:
forall a r. 0 <= r -> mul (zero r a.rows) a = zero r a.cols
(* The opposite can be moved out of either factor of a product:
   (opp a) * b = opp (a * b) = a * (opp b). The by clause proves it through
   sums that also contain ab and its opposite oab. *)
lemma mul_opp:
forall a b. a.cols = b.rows ->
let oa = opp a in
let ob = opp b in
let ab = mul a b in
let oab = opp ab in
mul oa b = oab = mul a ob
by add (mul oa b) (add ab oab) = oab = add (mul a ob) (add ab oab)
end
(* Block decomposition of integer matrices and its interaction with the
   product: extraction of a rectangular sub-matrix (block), splitting a
   matrix into two column blocks or two row blocks, and lemmas relating the
   product of matrices to products of their blocks. *)
module BlockMul
use import int.Int
use import int.Sum
use import sum_extended.Sum_extended
use import FloatMatrix
use import MatrixArithmetic
(* ofs2 a ai aj is the cell reader of a shifted by the offset (ai, aj):
   position (i, j) maps to cell (ai + i, aj + j) of a. *)
function ofs2 (a: mat int) (ai aj: int) : int -> int -> int
= \i j. get a (ai + i) (aj + j)
(* block a r dr c dc: the dr-by-dc sub-matrix of a whose top-left corner is
   cell (r, c) of a. *)
function block (a: mat int) (r dr c dc: int) : mat int =
create dr dc (ofs2 a r c)
(* c_blocks a a1 a2: a is split by columns. a1 is the block made of the
   first a1.cols columns (all rows) and a2 is the block made of the
   remaining columns. *)
predicate c_blocks (a a1 a2: mat int) =
0 <= a1.cols <= a.cols /\ a1 = block a 0 a.rows 0 a1.cols /\
a2 = block a 0 a.rows a1.cols (a.cols - a1.cols)
(* r_blocks a a1 a2: a is split by rows. a1 is the block made of the first
   a1.rows rows (all columns) and a2 is the block made of the remaining
   rows. *)
predicate r_blocks (a a1 a2: mat int) =
0 <= a1.rows <= a.rows /\ a1 = block a 0 a1.rows 0 a.cols /\
a2 = block a a1.rows (a.rows - a1.rows) 0 a.cols
(* Ghost lemma function, recursive on k (variant k), about the sum with
   bounds 0 and k of the terms of a * b, when a is split by columns into
   a1, a2 and b is split by rows into b1, b2:
   - while k stays within a1.cols, it equals the same sum for a1 * b1;
   - once k is at or past a1.cols, it equals the full sum for a1 * b1 plus
     the sum with bounds 0 and k - a1.cols for a2 * b2.
   For 0 < k the body shadows k with k - 1, asserts that the term of a * b
   at that index is the matching term of a1 * b1 or a2 * b2, and then makes
   the recursive call. For k = 0 the body does nothing. *)
let rec ghost block_mul_ij (a a1 a2 b b1 b2: mat int) (k: int) : unit
requires { a.cols = b.rows /\ a1.cols = b1.rows}
requires { 0 <= k <= a.cols }
requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
ensures { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
0 <= k <= a1.cols ->
sum 0 k (mul_atom a b i j) = sum 0 k (mul_atom a1 b1 i j) }
ensures { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
a1.cols <= k <= a.cols ->
sum 0 k (mul_atom a b i j) =
sum 0 a1.cols (mul_atom a1 b1 i j) +
sum 0 (k - a1.cols) (mul_atom a2 b2 i j) }
variant { k }
= if 0 < k then begin
let k = k - 1 in
assert { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
mul_atom a b i j k = if k < a1.cols then mul_atom a1 b1 i j k
else mul_atom a2 b2 i j (k - a1.cols) };
block_mul_ij a a1 a2 b b1 b2 k
end
(* Product by blocks: if a is split by columns into a1, a2 and b is split by
   rows into b1, b2, then a * b = a1 * b1 + a2 * b2. Obtained by calling
   block_mul_ij with k = a.cols and concluding by extensional equality. *)
let lemma mul_split (a a1 a2 b b1 b2: mat int) : unit
requires { a.cols = b.rows /\ a1.cols = b1.rows}
requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
ensures {add (mul a1 b1) (mul a2 b2) = mul a b }
= block_mul_ij a a1 a2 b b1 b2 a.cols;
assert { add (mul a1 b1) (mul a2 b2) == mul a b }
(* Cell-wise version of mul_block: cell (r + i, c + j) of a * b equals cell
   (i, j) of the product of the row band of a (rows starting at r, dr of
   them, all columns) by the column band of b (columns starting at c, dc of
   them, all rows). The body is a single call to sum_ext relating the terms
   of both products. *)
let lemma mul_block_cell (a b: mat int) (r dr c dc i j: int) : unit
requires { a.cols = b.rows }
requires { 0 <= r /\ r + dr <= a.rows }
requires { 0 <= c /\ c + dc <= b.cols }
requires { 0 <= i < dr /\ 0 <= j < dc }
ensures { ofs2 (mul a b) r c i j =
get (mul (block a r dr 0 a.cols) (block b 0 b.rows c dc)) i j }
= let a' = block a r dr 0 a.cols in
let b' = block b 0 b.rows c dc in
sum_ext 0 a.cols (mul_atom a b (i + r) (j + c)) (mul_atom a' b' i j)
(* A block of a product is the product of the corresponding row band of the
   left operand by the corresponding column band of the right operand.
   Stated on whole matrices and reduced to extensional equality. *)
lemma mul_block:
forall a b: mat int, r dr c dc: int.
a.cols = b.rows -> 0 <= r <= r + dr <= a.rows ->
0 <= c <= c + dc <= b.cols ->
let a' = block a r dr 0 a.cols in
let b' = block b 0 b.rows c dc in
let m' = block (mul a b) r dr c dc in
m' = mul a' b'
by m' == mul a' b'
end
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE why3session PUBLIC "-//Why3//proof session v5//EN"
"http://why3.lri.fr/why3session.dtd">
<why3session shape_version="4">
<prover id="0" name="CVC4" version="1.4" timelimit="5" steplimit="0" memlimit="1000"/>
<prover id="1" name="Alt-Ergo" version="1.01" timelimit="5" steplimit="0" memlimit="1000"/>
<file name="../matrices.mlw" expanded="true">
<theory name="MatrixGen" sum="d41d8cd98f00b204e9800998ecf8427e" expanded="true">
</theory>
<theory name="FloatMatrix" sum="d41d8cd98f00b204e9800998ecf8427e" expanded="true">
</theory>
<theory name="FixedMatrix" sum="81d1f6ecddc981db1f906002ad2c6515">
<goal name="rows_and_cols_nonnegative">
<proof prover="1"><result status="valid" time="0.00" steps="2"/></proof>
</goal>
</theory>
<theory name="SquareFixedMatrix" sum="8866fe956f3cb17788b7229e8763a520">
<goal name="r_and_c_nonnegative">
<proof prover="1"><result status="valid" time="0.00" steps="1"/></proof>
</goal>
</theory>
<theory name="MatrixArithmetic" sum="95c5ea5a8476dec9b6807e6463fa7c50">
<goal name="zero_neutral">
<transf name="split_goal_wp">
<goal name="zero_neutral.1" expl="1.">
<proof prover="1"><result status="valid" time="0.02" steps="30"/></proof>
</goal>
<goal name="zero_neutral.2" expl="2.">
<proof prover="1"><result status="valid" time="0.02" steps="9"/></proof>
</goal>
</transf>
</goal>
<goal name="add_commutative">
<transf name="split_goal_wp">
<goal name="add_commutative.1" expl="1.">
<proof prover="1"><result status="valid" time="0.03" steps="46"/></proof>
</goal>
<goal name="add_commutative.2" expl="2.">
<proof prover="1"><result status="valid" time="0.01" steps="17"/></proof>
</goal>
</transf>
</goal>
<goal name="add_associative">
<transf name="split_goal_wp">
<goal name="add_associative.1" expl="1.">
<proof prover="1"><result status="valid" time="0.14" steps="301"/></proof>
</goal>
<goal name="add_associative.2" expl="2.">
<proof prover="1"><result status="valid" time="0.01" steps="24"/></proof>
</goal>
</transf>
</goal>
<goal name="add_opposite">
<transf name="split_goal_wp">
<goal name="add_opposite.1" expl="1.">
<proof prover="1"><result status="valid" time="0.12" steps="153"/></proof>
</goal>
<goal name="add_opposite.2" expl="2.">
<proof prover="1"><result status="valid" time="0.01" steps="14"/></proof>
</goal>
</transf>
</goal>
<goal name="opp_involutive">
<proof prover="0"><result status="valid" time="0.07"/></proof>
</goal>
<goal name="opposite_add">
<proof prover="0"><result status="valid" time="0.25"/></proof>
</goal>
<goal name="WP_parameter mul_assoc_get" expl="VC for mul_assoc_get">
<transf name="split_goal_wp">
<goal name="WP_parameter mul_assoc_get.1" expl="1. precondition">
<proof prover="1"><result status="valid" time="0.02" steps="21"/></proof>
</goal>
<goal name="WP_parameter mul_assoc_get.2" expl="2. precondition">
<proof prover="1"><result status="valid" time="2.40" steps="94"/></proof>
</goal>
<goal name="WP_parameter mul_assoc_get.3" expl="3. assertion">
<proof prover="1"><result status="valid" time="0.06" steps="37"/></proof>
</goal>
<goal name="WP_parameter mul_assoc_get.4" expl="4. precondition">
<proof prover="1"><result status="valid" time="0.05" steps="35"/></proof>
</goal>
<goal name="WP_parameter mul_assoc_get.5" expl="5. assertion">
<proof prover="1"><result status="valid" time="0.04" steps="34"/></proof>
</goal>
<goal name="WP_parameter mul_assoc_get.6" expl="6. postcondition">
<proof prover="1"><result status="valid" time="0.02" steps="11"/></proof>
</goal>
</transf>
</goal>
<goal name="mul_assoc">
<transf name="split_goal_wp">
<goal name="mul_assoc.1" expl="1.">
<proof prover="1"><result status="valid" time="0.02" steps="60"/></proof>
</goal>
<goal name="mul_assoc.2" expl="2.">
<proof prover="1"><result status="valid" time="0.01" steps="18"/></proof>
</goal>
</transf>
</goal>
<goal name="WP_parameter mul_distr_right_get" expl="VC for mul_distr_right_get">
<transf name="split_goal_wp">
<goal name="WP_parameter mul_distr_right_get.1" expl="1. assertion">
<proof prover="1"><result status="valid" time="0.06" steps="106"/></proof>
</goal>
<goal name="WP_parameter mul_distr_right_get.2" expl="2. precondition">
<proof prover="1"><result status="valid" time="0.10" steps="53"/></proof>
</goal>
<goal name="WP_parameter mul_distr_right_get.3" expl="3. postcondition">
<proof prover="1"><result status="valid" time="0.06" steps="61"/></proof>
</goal>
</transf>
</goal>
<goal name="mul_distr_right">
<transf name="split_goal_wp">
<goal name="mul_distr_right.1" expl="1.">
<proof prover="1"><result status="valid" time="0.05" steps="146"/></proof>
</goal>
<goal name="mul_distr_right.2" expl="2.">
<proof prover="1"><result status="valid" time="0.01" steps="22"/></proof>
</goal>
</transf>
</goal>
<goal name="WP_parameter mul_distr_left_get" expl="VC for mul_distr_left_get">
<transf name="split_goal_wp">
<goal name="WP_parameter mul_distr_left_get.1" expl="1. assertion">
<proof prover="1"><result status="valid" time="0.06" steps="110"/></proof>
</goal>
<goal name="WP_parameter mul_distr_left_get.2" expl="2. precondition">
<proof prover="1"><result status="valid" time="0.05" steps="40"/></proof>
</goal>
<goal name="WP_parameter mul_distr_left_get.3" expl="3. postcondition">
<proof prover="1"><result status="valid" time="0.05" steps="45"/></proof>
</goal>
</transf>
</goal>
<goal name="mul_distr_left">
<transf name="split_goal_wp">
<goal name="mul_distr_left.1" expl="1.">
<proof prover="1"><result status="valid" time="0.04" steps="144"/></proof>
</goal>
<goal name="mul_distr_left.2" expl="2.">
<proof prover="1"><result status="valid" time="0.01" steps="22"/></proof>
</goal>
</transf>
</goal>
<goal name="mul_zero_right">
<proof prover="1"><result status="valid" time="1.91" steps="404"/></proof>
</goal>
<goal name="mul_zero_left">
<proof prover="1"><result status="valid" time="1.65" steps="396"/></proof>
</goal>
<goal name="mul_opp">
<transf name="split_goal_wp">
<goal name="mul_opp.1" expl="1.">
<proof prover="1"><result status="valid" time="0.36" steps="394"/></proof>
</goal>
<goal name="mul_opp.2" expl="2.">
<proof prover="0"><result status="valid" time="0.08"/></proof>
</goal>
<goal name="mul_opp.3" expl="3.">
<proof prover="0"><result status="valid" time="0.04"/></proof>
</goal>
<goal name="mul_opp.4" expl="4.">
<proof prover="1"><result status="valid" time="3.92" steps="2401"/></proof>
</goal>
</transf>
</goal>
</theory>
<theory name="BlockMul" sum="08ae5cbffa4e68e45d31bc65f734c776">
<goal name="WP_parameter block_mul_ij" expl="VC for block_mul_ij">
<proof prover="1"><result status="valid" time="2.31" steps="880"/></proof>
</goal>
<goal name="WP_parameter mul_split" expl="VC for mul_split">
<proof prover="1"><result status="valid" time="0.29" steps="218"/></proof>
</goal>
<goal name="WP_parameter mul_block_cell" expl="VC for mul_block_cell">
<transf name="split_goal_wp">
<goal name="WP_parameter mul_block_cell.1" expl="1. precondition">
<proof prover="1"><result status="valid" time="0.03" steps="27"/></proof>
</goal>
<goal name="WP_parameter mul_block_cell.2" expl="2. postcondition">
<proof prover="1"><result status="valid" time="0.12" steps="120"/></proof>
</goal>
</transf>
</goal>
<goal name="mul_block">
<transf name="split_goal_wp">
<goal name="mul_block.1" expl="1.">
<proof prover="1"><result status="valid" time="0.04" steps="65"/></proof>
</goal>
<goal name="mul_block.2" expl="2.">
<proof prover="1"><result status="valid" time="0.02" steps="24"/></proof>
</goal>
</transf>
</goal>
</theory>
</file>
</why3session>