Commit 1dbdd078 authored by Martin Clochard's avatar Martin Clochard

Example verifythis_2016_matrix_multiplication: Strassen's algorithm adapted

to handle non-square matrices
parent b00248b5
......@@ -241,7 +241,7 @@ module MatrixMultiplication
= assert { block (block a.mdl r1 dr1 c1 dc1) r2 dr2 c2 dc2 ==
block a.mdl (r1+r2) dr2 (c1+c2) dc2 }
let padding (a: matrix int) (r c: int)
let padding (a: matrix int) (r c: int) : matrix int
requires { a.mdl.F.rows <= r }
requires { a.mdl.F.cols <= c }
ensures { result.mdl.F.rows = r }
......@@ -265,44 +265,54 @@ module MatrixMultiplication
use import int.EuclideanDivision
let rec strassen (a b: matrix int) (ghost flag: int)
requires { 0 <= flag }
requires { flag = 0 -> a.mdl.F.cols = 1 \/
exists k. a.mdl.F.cols = 2 * k }
requires { a.mdl === b.mdl /\ a.mdl.F.cols = b.mdl.F.rows }
ensures { result.mdl = mul a.mdl b.mdl }
ensures { result.mdl === a.mdl }
variant { a.mdl.F.cols + flag, flag }
let rec strassen (a b: matrix int) (ghost flag:int) : matrix int
requires { a.mdl.F.cols = b.mdl.F.rows }
requires { flag >= 0 }
requires { flag = 0 ->
a.mdl.F.rows = 1 \/ a.mdl.F.cols = 1 \/ b.mdl.F.cols = 1 \/
exists k l m. a.mdl.F.rows = 2*k
/\ a.mdl.F.cols = 2*l
/\ b.mdl.F.cols = 2*m }
ensures { result.mdl = mul a.mdl b.mdl }
ensures { result.mdl.F.rows = a.mdl.F.rows }
ensures { result.mdl.F.cols = b.mdl.F.cols }
variant { a.mdl.F.rows + a.mdl.F.cols + b.mdl.F.cols + 3 * flag, flag }
= let cut_off = abstract ensures { result >= 1 } 42 end in
let s = a.columns in
assert { 0 <= s = a.mdl.F.cols };
if s <= cut_off then mul_naive a b else
let (n,r) = abstract
ensures { let (n,r) = result in s = 2 * n + r /\ 0 <= r <= 1 }
(div s 2,mod s 2)
end in
if r <> 0
then begin (* padding *)
let s' = s + 1 in
let ap = padding a s' s' in
let bp = padding b s' s' in
assert { s' = 2 * (n + 1) };
let m = strassen ap bp 0 in
ghost (double_block ap 0 s 0 s' 0 s 0 s;
double_block ap 0 s 0 s' 0 s s 1;
double_block bp 0 s' 0 s 0 s 0 s;
double_block bp 0 s' 0 s s 1 0 s);
assert { c_blocks (block ap.mdl 0 s 0 s') a.mdl (zero s 1) };
assert { r_blocks (block bp.mdl 0 s' 0 s ) b.mdl (zero 1 s) };
block m 0 s 0 s
end
else begin
let rw = a.rows in
let md = a.columns in
let cl = b.columns in
assert { rw = a.mdl.F.rows /\ md = a.mdl.F.cols /\ cl = b.mdl.F.cols };
if rw <= cut_off || md <= cut_off || cl <= cut_off
then mul_naive a b else
let div2 (n: int) : (int,int)
requires { 0 <= n }
returns { (q,r) -> n = 2 * q + r /\ 0 <= r <= 1 /\ n + r = 2 * (q+r) }
= (div n 2,mod n 2) in
let (qr,rr) = div2 rw in
let (qm,rm) = div2 md in
let (qc,rc) = div2 cl in
if rr <> 0 || rm <> 0 || rc <> 0
then begin (* Padding *)
let rw' = rw + rr in
let md' = md + rm in
let cl' = cl + rc in
let ap = padding a rw' md' in
let bp = padding b md' cl' in
let m = strassen ap bp 0 in
ghost (double_block ap 0 rw 0 md' 0 rw 0 md;
double_block ap 0 rw 0 md' 0 rw md rm;
double_block bp 0 md' 0 cl 0 md 0 cl;
double_block bp 0 md' 0 cl md rm 0 cl);
assert { c_blocks (block ap.mdl 0 rw 0 md') a.mdl (zero rw rm) };
assert { r_blocks (block bp.mdl 0 md' 0 cl ) b.mdl (zero rm cl) };
block m 0 rw 0 cl
end else begin
(* Regular Strassen multiplication *)
let ghost gm = mul_naive a b in
let ghost gm11 = block gm 0 n 0 n in
let ghost gm12 = block gm 0 n n n in
let ghost gm21 = block gm n n 0 n in
let ghost gm22 = block gm n n n n in
let ghost gm11 = block gm 0 qr 0 qc in
let ghost gm12 = block gm 0 qr qc qc in
let ghost gm21 = block gm qr qr 0 qc in
let ghost gm22 = block gm qr qr qc qc in
let (m11, m12, m21, m22) = abstract
ensures { let (m11, m12, m21, m22) = result in
mdl m11 = mdl gm11 /\ mdl m12 = mdl gm12 /\
......@@ -310,27 +320,35 @@ module MatrixMultiplication
let ghost e = symb_env () in
let mul_ws (a b: with_symb) : with_symb
requires { with_symb_vld e a /\ with_symb_vld e b }
requires { a.phy.mdl === b.phy.mdl === gm11.mdl }
requires { a.phy.mdl.F.rows = qr /\ a.phy.mdl.F.cols = qm }
requires { b.phy.mdl.F.rows = qm /\ b.phy.mdl.F.cols = qc }
ensures { with_symb_vld e result }
ensures { result.phy.mdl = mul a.phy.mdl b.phy.mdl }
ensures { result.sym ---> symb_mul a.sym b.sym }
= let r = strassen a.phy b.phy (if n = 1 then 0 else 1) in
= let ghost flag = if qr = 1 || qm = 1 || qc = 1 then 0 else 1 in
let r = strassen a.phy b.phy flag in
{ phy = r; sym = ghost symb_mul e a.sym b.sym }
in
(* a blocks *)
let a11 = block_ws e a 0 n 0 n in let a12 = block_ws e a 0 n n n in
let a21 = block_ws e a n n 0 n in let a22 = block_ws e a n n n n in
ghost (double_block a 0 n 0 s 0 n 0 n; double_block a 0 n 0 s 0 n n n;
double_block a n n 0 s 0 n 0 n; double_block a n n 0 s 0 n n n);
assert { c_blocks (block a.mdl 0 n 0 s) a11.phy.mdl a12.phy.mdl };
assert { c_blocks (block a.mdl n n 0 s) a21.phy.mdl a22.phy.mdl };
let a11 = block_ws e a 0 qr 0 qm in
let a12 = block_ws e a 0 qr qm qm in
let a21 = block_ws e a qr qr 0 qm in
let a22 = block_ws e a qr qr qm qm in
ghost (double_block a 0 qr 0 md 0 qr 0 qm;
double_block a 0 qr 0 md 0 qr qm qm;
double_block a qr qr 0 md 0 qr 0 qm;
double_block a qr qr 0 md 0 qr qm qm);
assert { c_blocks (block a.mdl 0 qr 0 md) a11.phy.mdl a12.phy.mdl };
assert { c_blocks (block a.mdl qr qr 0 md) a21.phy.mdl a22.phy.mdl };
(* b blocks *)
let b11 = block_ws e b 0 n 0 n in let b12 = block_ws e b 0 n n n in
let b21 = block_ws e b n n 0 n in let b22 = block_ws e b n n n n in
ghost (double_block b 0 s 0 n 0 n 0 n; double_block b 0 s 0 n n n 0 n;
double_block b 0 s n n 0 n 0 n; double_block b 0 s n n n n 0 n);
assert { r_blocks (block b.mdl 0 s 0 n) b11.phy.mdl b21.phy.mdl };
assert { r_blocks (block b.mdl 0 s n n) b12.phy.mdl b22.phy.mdl };
let b11 = block_ws e b 0 qm 0 qc in let b12 = block_ws e b 0 qm qc qc in
let b21 = block_ws e b qm qm 0 qc in let b22 = block_ws e b qm qm qc qc in
ghost (double_block b 0 md 0 qc 0 qm 0 qc;
double_block b 0 md 0 qc qm qm 0 qc;
double_block b 0 md qc qc 0 qm 0 qc;
double_block b 0 md qc qc qm qm 0 qc);
assert { r_blocks (block b.mdl 0 md 0 qc) b11.phy.mdl b21.phy.mdl };
assert { r_blocks (block b.mdl 0 md qc qc) b12.phy.mdl b22.phy.mdl };
let ghost egm11 = (add_ws e (mul_ws a11 b11) (mul_ws a12 b21)).sym in
let ghost egm21 = (add_ws e (mul_ws a21 b11) (mul_ws a22 b21)).sym in
let ghost egm12 = (add_ws e (mul_ws a11 b12) (mul_ws a12 b22)).sym in
......@@ -357,21 +375,20 @@ module MatrixMultiplication
(m11.phy, m12.phy, m21.phy, m22.phy)
end in
let res = make a.rows b.columns 0 in
blit m11 res 0 0 n 0 0 n;
blit m12 res 0 0 n 0 n n;
blit m21 res 0 n n 0 0 n;
blit m22 res 0 n n 0 n n;
blit m11 res 0 0 qr 0 0 qc;
blit m12 res 0 0 qr 0 qc qc;
blit m21 res 0 qr qr 0 0 qc;
blit m22 res 0 qr qr 0 qc qc;
assert { res.mdl == gm.mdl by
forall i j. 0 <= i < s -> 0 <= j < s ->
forall i j. 0 <= i < rw -> 0 <= j < cl ->
res.elts[i][j] = F.get gm.mdl i j by
if i < n then
if j < n then res.elts[i][j] = F.get m11.mdl i j
else res.elts[i][j] = F.get m12.mdl i (j - n)
if i < qr then
if j < qc then res.elts[i][j] = F.get m11.mdl i j
else res.elts[i][j] = F.get m12.mdl i (j - qc)
else
if j < n then res.elts[i][j] = F.get m21.mdl (i-n) j
else res.elts[i][j] = F.get m22.mdl (i-n) (j - n) };
if j < qc then res.elts[i][j] = F.get m21.mdl (i - qr) j
else res.elts[i][j] = F.get m22.mdl (i - qr) (j - qc) };
res
end
end
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment