Commit f3b04e9c authored by Raphael Rieu-Helft's avatar Raphael Rieu-Helft
Browse files

preuve de strassen par réflexion

parent ab3922d4
......@@ -574,8 +574,30 @@ predicate (===) (m1 m2: mat F.t) =
forall i j: int. 0 <= i -> 0 <= j ->
row_zeros m1 i = row_zeros m2 i /\ col_zeros m1 j = col_zeros m2 j
predicate in_bounds (m: mat F.t) (i j: int) =
0 <= i < col_zeros m j /\ 0 <= j < row_zeros m i
lemma oob_zero:
forall m: mat F.t, i j: int. 0 <= i -> 0 <= j -> not in_bounds m i j
-> get m i j = F.zero
predicate size (m: mat F.t) (r c: int) =
forall i j. 0 <= i -> 0 <= j -> row_zeros m i = c /\ col_zeros m j = r
(forall i: int. 0 <= i -> row_zeros m i = c)
/\ (forall j: int. 0 <= j -> col_zeros m j = r)
lemma size_to_bounds:
forall m: mat F.t, r c i j: int. size m r c -> (in_bounds m i j <-> (0 <= i < r /\ 0 <= j < c))
lemma iso_size:
forall a b: mat F.t, r c: int. a === b -> (size a r c <-> size b r c)
lemma size_rows_ib:
forall a: mat F.t, r c i: int. size a r c ->
0 <= i < r -> row_zeros a i = c
by forall j: int. in_bounds a i j -> 0 <= j < c
lemma size_iso:
forall a b: mat F.t, r c: int. size a r c -> size b r c -> a === b
end
......@@ -617,17 +639,6 @@ module InfMatrix
0 <= i -> 0 <= j -> (i >= cz j \/ j >= rz i) ->
get (create rz cz f) i j = tzero
function fcreate (r c: int) (f: int -> int -> t) : mat =
create (fun _ -> c) (fun _ -> r) f
lemma fcreate_get_ib:
forall r c i j: int, f: int -> int -> t.
0 <= i < r -> 0 <= j < c -> get (fcreate r c f) i j = f i j
lemma fcreate_get_oob:
forall r c i j: int, f: int -> int -> t.
0 <= i -> 0 <= j -> (i >= r \/ j >= c) -> get (fcreate r c f) i j = tzero
function set (m: mat) (i j:int) (v:t) : mat =
create
......@@ -635,7 +646,6 @@ module InfMatrix
(fun j1 -> if j1 = j then max (i+1) (col_zeros m j) else col_zeros m j1)
(fun i1 j1 -> if i1 = i && j1 = j then v else get m i1 j1)
clone export InfMatrixGen with type mat 'a = mat,
type F.t = t,
function get = get,
......@@ -650,7 +660,24 @@ module InfMatrix
lemma set_def_rowz_unchanged,
lemma set_def_other_rowz,
lemma set_def_other_colz
(*
lemma create_ib:
forall rz cz: int -> int, f: int -> int -> t, i j: int.
in_bounds (create rz cz f) i j -> get (create rz cz f) i j = f i j
*)
function fcreate (r c: int) (f: int -> int -> t) : mat =
create (fun _ -> c) (fun _ -> r) f
(*create (fun i -> if 0 <= i < r then c else 0)
(fun j -> if 0 <= j < c then r else 0)
f*)
lemma fcreate_get_ib:
forall r c i j: int, f: int -> int -> t.
0 <= i < r -> 0 <= j < c -> get (fcreate r c f) i j = f i j
lemma fcreate_get_oob:
forall r c i j: int, f: int -> int -> t.
0 <= i -> 0 <= j -> (i >= r \/ j >= c) -> get (fcreate r c f) i j = tzero
lemma fcreate_size:
forall r c: int, f: int -> int -> t. size (fcreate r c f) r c
......@@ -710,6 +737,13 @@ theory MaxFun
= max (maxf f a b) (maxf f b c) }
end
let rec lemma max_constant (f: int -> int) (v a b: int)
requires { v >= 0 }
requires { a < b }
requires { forall i. a <= i < b -> f i = v }
ensures { maxf f a b = v }
variant { b-a }
= if a = b-1 then () else max_constant f v a (b-1)
end
......@@ -724,7 +758,7 @@ module InfIntMatrix
constant zerof : int -> int -> int = fun _ _ -> 0
constant zero : mat = fcreate 0 0 zerof
constant mzero : mat = fcreate 0 0 zerof
function zerorc (r c: int) : mat = fcreate r c zerof
......@@ -753,11 +787,21 @@ module InfIntMatrix
forall a b: mat, i j: int. 0 <= i -> 0 <= j ->
get (add a b) i j = get a i j + get b i j
lemma add_sizes:
lemma add_iso:
forall a b: mat. a === b -> a === add a b === b
lemma add_size:
forall a b: mat, r c: int. size a r c -> size b r c -> size (add a b) r c
by (forall i j:int. (in_bounds a i j \/ in_bounds b i j) <-> in_bounds (add a b) i j)
lemma add_assoc:
forall a b c: mat. add (add a b) c = add a (add b c) by add (add a b) c == add a (add b c)
lemma add_commutative:
forall a b: mat. add a b = add b a by add a b == add b a
lemma zero_neutral:
forall a. add a zero = a by add a zero == a
forall a. add a mzero = a by add a mzero == a
(* Matrix additive inverse *)
function opp2f (a: mat) : int -> int -> int =
......@@ -783,7 +827,7 @@ module InfIntMatrix
else k >= col_zeros b j so get b k j = 0
function mul_cell_bound (a b: mat) (i j: int) : int
= min (row_zeros a i) (col_zeros b j)
= min (row_zeros a i) (col_zeros b j) (* row_zeros a i*)
function mul_cell (a b: mat) : int -> int -> int =
fun i j -> sum (mul_atom a b i j) 0 (mul_cell_bound a b i j)
......@@ -815,66 +859,310 @@ module InfIntMatrix
(fun j -> maxf (fun k -> col_zeros a k) 0 (col_zeros b j))
(mul_cell a b)
lemma mul_sizes:
forall m1 m2: mat, m n p: int.
size m1 m n -> size m2 n p -> size (mul m1 m2) m p
let lemma mul_sizes (m1 m2: mat) (m n p: int)
requires { size m1 m n /\ size m2 n p }
requires { 0 < m /\ 0 < n /\ 0 < p }
ensures { size (mul m1 m2) m p }
=
let r = mul m1 m2 in
max_constant (fun k -> row_zeros m2 k) p 0 n;
assert { forall i. 0 <= i -> row_zeros r i = p };
max_constant (fun k -> col_zeros m1 k) m 0 n;
assert { forall j. 0 <= j -> col_zeros r j = m }
lemma id_neutral_r:
forall m: mat. mul m id = m by mul m id == m
forall m: mat. mul m id = m
by (mul m id == m
by (forall i j. in_bounds m i j -> mul_cell m id i j = get m i j)
/\ (forall i j. 0 <= i -> 0 <= j -> not (in_bounds m i j)
-> mul_cell m id i j = 0 = get m i j))
lemma id_neutral_l:
forall m: mat. mul id m = m by mul id m == m
forall m: mat. mul id m = m
by (mul id m == m
by (forall i j. in_bounds m i j ->
mul_cell id m i j = get m i j
by let t = mul_cell_bound id m i j in
sum (mul_atom id m i j) 0 i = 0
so sum (mul_atom id m i j) (i+1) t = 0
so mul_cell id m i j
= sum (mul_atom id m i j) 0 t
= sum (mul_atom id m i j) 0 i + mul_atom id m i j i + sum (mul_atom id m i j) (i+1) t
= 0 + (get id i i)*(get m i j) + 0
= get m i j)
/\ (forall i j. 0 <= i -> 0 <= j -> not (in_bounds m i j)
-> mul_cell id m i j = 0 = get m i j))
use import sum_extended.Sum_extended (* examples/verify_this_2016/matrix_multiplication *)
function ft1 (a b c: mat) (i j: int) : int -> int -> int =
fun k -> smulf (mul_atom a b i k) (get c k j)
function ft2 (a b c: mat) (i j: int) : int -> int -> int =
fun k -> smulf (mul_atom b c k j) (get a i k)
let lemma mul_assoc_get (a b c: mat) (i j: int)
(* requires { 0 <= i < row_zeros a j /\ 0 <= j < col_zeros c i } *)
requires { 0 <= i /\ 0 <= j }
ensures { get (mul (mul a b) c) i j = get (mul a (mul b c)) i j }
= let ft1 = ft1 a b c i j in
let ft2 = ft2 a b c i j in
fubini ft1 ft2 0 (col_zeros b i) 0 (col_zeros a i);
sum_ext (mul_atom (mul a b) c i j) (sumf ft1 0 (col_zeros a i)) 0 (col_zeros b i);
assert { get (mul (mul a b) c) i j = sum (sumf ft1 0 (col_zeros a i)) 0 (col_zeros b i) };
sum_ext (mul_atom a (mul b c) i j) (sumf ft2 0 (col_zeros b i)) 0 (col_zeros a i);
assert { get (mul a (mul b c)) i j = sum (sumf ft2 0 (col_zeros b i)) 0 (col_zeros a i) }
lemma mul_assoc:
forall a b c: mat.
let ab = mul a b in
let bc = mul b c in
let a_bc = mul a bc in
let ab_c = mul ab c in
a_bc = ab_c
by a_bc == ab_c
(* External product *)
function extf (c: int) (a: mat): int -> int -> int =
fun x y -> c * (get a x y)
end
function extp (c: int) (a: mat) : mat =
create (fun i -> row_zeros a i) (fun j -> col_zeros a j) (extf c a)
lemma ext_iso:
forall m: mat, r: int. extp r m === m
lemma ext_get:
forall m: mat, r i j: int. 0 <= i -> 0 <= j ->
get (extp r m) i j = r * (get m i j)
(*
use import int.MinMax
lemma ext_dist_sum_mat:
forall x y: mat, r: int. extp r (add x y) = add (extp r x) (extp r y)
by extp r (add x y) == add (extp r x) (extp r y)
use import matrices.Matrix
use import matrices.MatrixArithmetic
use import matrices.BlockMul
lemma ext_dist_sum_r:
forall x: mat, r s: int. extp (r+s) x = add (extp r x) (extp s x)
by extp (r+s) x == add (extp r x) (extp s x)
lemma assoc_mul_ext:
forall x: mat, r s: int. extp (r*s) x = extp r (extp s x)
by extp (r*s) x == extp r (extp s x)
lemma unit_ext:
forall x: mat. extp 1 x = x by extp 1 x == x
constant d : int
axiom DimNonNeg : d >= 0
lemma comm_mul_ext:
forall x y: mat, r: int. extp r (mul x y) = mul (extp r x) y = mul x (extp r y)
by extp r (mul x y) == mul (extp r x) y == mul x (extp r y)
function extf (c: int) (a:mat int) : int -> int -> int =
fun x y -> c * (get a x y)
end
function ext (c: int) (a:mat int) : mat int =
create (rows a) (cols a) (extf c a)
module InfIntMatrixDecision
function addm (a b: mat int) : mat int =
create (max a.rows b.rows) (max a.cols b.cols) (add2f a b)
use import InfIntMatrix
use import int.Int
constant zero11 : mat int = zero 0 0
constant one11 : mat int = create d d (fun x y -> if x=y then 1 else 0)
let predicate eq0_int (x:int) = x = 0
let predicate eq0_int (x:int) = x=0
clone export AssocAlgebraDecision with type r = int, type a = mat, constant R.zero = Int.zero, constant R.one = Int.one, function R.(+) = (+), function R.(-_) = (-_), function R.(*) = (*),constant A.zero = mzero, constant one = id, function (+) = add, function A.(-_) = opp, function ( *) = mul, function ($) = extp, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc, goal A.Mul_distr_l, goal A.Mul_distr_r
clone export AssocAlgebraDecision with type r = int, type a = mat int, constant R.zero = Int.zero, constant R.one = Int.one, function R.(+) = (+), function R.(-_) = (-_), function R.(*) = (*),constant A.zero = zero11, constant one = one11, function (+) = add, function A.(-_) = opp, function (*) = mul, function ($) = ext, goal AUnitary, goal ANonTrivial, goal ExtDistSumA, goal ExtDistSumR, goal AssocMulExt, goal UnitExt, goal CommMulExt, val eq0 = eq0_int, goal A.MulAssoc.Assoc, goal A.Unit_def_l, goal A.Unit_def_r, goal A.Comm, goal A.Assoc
predicate quarters (a a11 a12 a21 a22: mat int) =
(rows a11 = rows a12 = rows a21 = rows a22 = cols a11 = cols a12 = cols a21 = cols a22) /\
rows a = cols a = 2 * rows a11 /\
a11 = block a 0 a11.rows 0 a11.cols /\ a12 = block a 0 a11.rows a11.cols a11.cols /\
a21 = block a a11.rows a11.rows 0 a11.cols /\ a22 = block a a11.rows a11.rows a11.cols a11.cols
let lemma naive_blocks (a b c a11 a12 a21 a22 b11 b12 b21 b22 c11 c12 c21 c22: mat int)
requires { quarters a a11 a12 a21 a22 }
requires { quarters b b11 b12 b21 b22 }
requires { quarters c c11 c12 c21 c22 }
requires { c11 = add (mul a11 b11) (mul b11 b22) }
requires { c12 = add (mul a11 b12) (mul a12 b22) }
requires { c21 = add (mul a21 b11) (mul a22 b21) }
requires { c22 = add (mul a21 b12) (mul a22 b22) }
ensures { c = mul a b }
=
()
end
module MatrixTests
use import InfIntMatrixDecision
use import InfIntMatrix
use import int.Int
use import int.Sum
use import sum_extended.Sum_extended
function cols (a: mat) : int (* if matrix is a finite rectangle, return number of cols *)
function rows (a: mat) : int
axiom rows_def:
forall a: mat, r c: int. size a r c -> rows a = r
axiom cols_def:
forall a: mat, r c: int. size a r c -> cols a = c
predicate is_finite (m: mat) = size m m.rows m.cols
function ofs2 (a: mat) (ai aj: int) : int -> int -> int
= fun i j -> get a (ai + i) (aj + j)
function block (a: mat) (r dr c dc: int) : mat =
fcreate dr dc (ofs2 a r c)
predicate c_blocks (a a1 a2: mat) =
0 <= a1.cols <= a.cols /\ a1 = block a 0 a.rows 0 a1.cols /\
a2 = block a 0 a.rows a1.cols (a.cols - a1.cols)
predicate r_blocks (a a1 a2: mat) =
0 <= a1.rows <= a.rows /\ a1 = block a 0 a1.rows 0 a.cols /\
a2 = block a a1.rows (a.rows - a1.rows) 0 a.cols
let rec lemma block_mul_ij (a a1 a2 b b1 b2: mat) (k: int)
requires { a.cols = b.rows /\ a1.cols = b1.rows}
requires { 0 <= k <= a.cols }
requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
ensures { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
0 <= k <= a1.cols ->
sum (mul_atom a b i j) 0 k = sum (mul_atom a1 b1 i j) 0 k }
ensures { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
a1.cols <= k <= a.cols ->
sum (mul_atom a b i j) 0 k =
sum (mul_atom a1 b1 i j) 0 a1.cols +
sum (mul_atom a2 b2 i j) 0 (k - a1.cols) }
variant { k }
= if 0 < k then begin
let k = k - 1 in
assert { forall i j. 0 <= i < a.rows -> 0 <= j < b.cols ->
mul_atom a b i j k = if k < a1.cols then mul_atom a1 b1 i j k
else mul_atom a2 b2 i j (k - a1.cols) };
block_mul_ij a a1 a2 b b1 b2 k
end
*)
\ No newline at end of file
let lemma mul_split (a a1 a2 b b1 b2: mat) : unit
requires { a.cols = b.rows /\ a1.cols = b1.rows}
requires { c_blocks a a1 a2 /\ r_blocks b b1 b2 }
ensures {add (mul a1 b1) (mul a2 b2) = mul a b }
= block_mul_ij a a1 a2 b b1 b2 a.cols;
assert { add (mul a1 b1) (mul a2 b2) == mul a b }
let lemma mul_block_cell (a b: mat) (r dr c dc i j: int) : unit
requires { is_finite a /\ is_finite b }
requires { a.cols = b.rows }
requires { 0 <= r /\ r + dr <= a.rows }
requires { 0 <= c /\ c + dc <= b.cols }
requires { 0 <= i < dr /\ 0 <= j < dc }
ensures { ofs2 (mul a b) r c i j =
get (mul (block a r dr 0 a.cols) (block b 0 b.rows c dc)) i j }
= let a' = block a r dr 0 a.cols in
let b' = block b 0 b.rows c dc in
sum_ext (mul_atom a b (i + r) (j + c)) (mul_atom a' b' i j) 0 a.cols;
assert { ofs2 (mul a b) r c i j = get (mul a b) (i+r) (j+c)
= sum (mul_atom a b (i+r) (j+c)) 0 a.cols
= sum (mul_atom a' b' i j) 0 a.cols
= get (mul a' b') i j }
lemma mul_block:
forall a b: mat, r dr c dc: int.
a.cols = b.rows -> 0 <= r <= r + dr <= a.rows ->
0 <= c <= c + dc <= b.cols ->
let a' = block a r dr 0 a.cols in
let b' = block b 0 b.rows c dc in
let m' = block (mul a b) r dr c dc in
m' = mul a' b'
by m' == mul a' b'
predicate quarters (a a11 a12 a21 a22: mat) =
(is_finite a /\ is_finite a11 /\ is_finite a12 /\ is_finite a21 /\ is_finite a22) /\
(rows a11 = rows a12 = rows a21 = rows a22 = cols a11 = cols a12 = cols a21 = cols a22) /\
rows a = cols a = 2 * rows a11 /\
a11 = block a 0 a11.rows 0 a11.cols /\ a12 = block a 0 a11.rows a11.cols a11.cols /\
a21 = block a a11.rows a11.rows 0 a11.cols /\ a22 = block a a11.rows a11.rows a11.cols a11.cols
(*
let lemma naive_blocks (a b c a11 a12 a21 a22 b11 b12 b21 b22 c11 c12 c21 c22: mat)
requires { is_finite a /\ is_finite b /\ is_finite c }
requires { quarters a a11 a12 a21 a22 }
requires { quarters b b11 b12 b21 b22 }
requires { quarters c c11 c12 c21 c22 }
requires { c11 = add (mul a11 b11) (mul b11 b22) }
requires { c12 = add (mul a11 b12) (mul a12 b22) }
requires { c21 = add (mul a21 b11) (mul a22 b21) }
requires { c22 = add (mul a21 b12) (mul a22 b22) }
ensures { c = mul a b }
=
assert { c == mul a b } Z3 proves this ?
*)
let lemma naive_blocks (a b c a11 a12 a21 a22 b11 b12 b21 b22 c11 c12 c21 c22: mat)
requires { is_finite a /\ is_finite b /\ is_finite c }
requires { quarters a a11 a12 a21 a22 }
requires { quarters b b11 b12 b21 b22 }
requires { quarters c c11 c12 c21 c22 }
requires { c11 = add (mul a11 b11) (mul a12 b21) }
requires { c12 = add (mul a11 b12) (mul a12 b22) }
requires { c21 = add (mul a21 b11) (mul a22 b21) }
requires { c22 = add (mul a21 b12) (mul a22 b22) }
ensures { c = mul a b }
=
assert { c == mul a b }
use import int.Power
use import number.Parity
use import int.ComputerDivision
let ghost function cut_quarters (a: mat) : (mat, mat, mat, mat)
requires { is_finite a }
requires { rows a = cols a }
requires { even (rows a) }
returns { (a11, a12, a21, a22) -> quarters a a11 a12 a21 a22 }
=
let s = div (rows a) 2 in
(block a 0 s 0 s, block a 0 s s s, block a s s 0 s, block a s s s s)
let ghost function paste_quarters (a11 a12 a21 a22: mat): mat
requires { is_finite a11 /\ is_finite a12 /\ is_finite a21 /\ is_finite a22 }
requires { rows a11 = rows a12 = rows a21 = rows a22
= cols a11 = cols a12 = cols a21 = cols a22 }
ensures { quarters result a11 a12 a21 a22 }
=
let s = rows a11 in
let r = fcreate (2 * s) (2 * s)
(fun i j -> if i < s && j < s then get a11 i j
else if i < s then get a12 i (j-s)
else if j < s then get a21 (i-s) j
else get a22 (i-s) (j-s)) in
assert { a11 = block r 0 s 0 s by a11 == block r 0 s 0 s };
assert { a12 = block r 0 s s s by a12 == block r 0 s s s };
assert { a21 = block r s s 0 s by a21 == block r s s 0 s };
assert { a22 = block r s s s s by a22 == block r s s s s };
r
meta "compute_max_steps" 0x100000
let rec ghost function strassen_pow2 (a b: mat) (ghost k: int)
requires { 0 <= k }
requires { size a (power 2 k) (power 2 k) }
requires { size b (power 2 k) (power 2 k) }
ensures { result = mul a b }
variant { k }
=
let cutoff = begin ensures { result >= 1 } 4 end in
if k <= cutoff then mul a b
else begin
let (a11, a12, a21, a22) = cut_quarters a in
let (b11, b12, b21, b22) = cut_quarters b in
let s = power 2 (k-1) in
assert { s > 0 by k-1 >= 1 so power 2 (k-1) >= power 2 1 = 2};
assert { size a11 s s /\ size a12 s s /\ size a21 s s /\ size a22 s s };
assert { size b11 s s /\ size b12 s s /\ size b21 s s /\ size b22 s s };
let ghost c11 = add (mul a11 b11) (mul a12 b21) in
let ghost c12 = add (mul a11 b12) (mul a12 b22) in
let ghost c21 = add (mul a21 b11) (mul a22 b21) in
let ghost c22 = add (mul a21 b12) (mul a22 b22) in
mul_sizes a11 b11 s s s;
assert { size c11 s s /\ size c12 s s /\ size c21 s s /\ size c22 s s };
let ghost c = paste_quarters c11 c12 c21 c22 in
assert { c = mul a b };
let m1 = strassen_pow2 (add a11 a22) (add b11 b22) (k-1) in
let m2 = strassen_pow2 (add a21 a22) b11 (k-1) in
let m3 = strassen_pow2 a11 (add b12 (extp (-1) b22)) (k-1) in
let m4 = strassen_pow2 a22 (add b21 (extp (-1) b11)) (k-1) in
let m5 = strassen_pow2 (add a11 a12) b22 (k-1) in
let m6 = strassen_pow2 (add a21 (extp (-1) a11)) (add b11 b12) (k-1) in
let m7 = strassen_pow2 (add a12 (extp (-1) a22)) (add b21 b22) (k-1) in
let s11 = add m1 (add m4 (add m7 (extp (-1) m5))) in
let s12 = add m3 m5 in
let s21 = add m2 m4 in
let s22 = add m1 (add m3 (add m6 (extp (-1) m2))) in
assert { s11 = c11 };
assert { s12 = c12 };
assert { s21 = c21 };
assert { s22 = c22 };
paste_quarters s11 s12 s21 s22
end
end
\ No newline at end of file
......@@ -126,6 +126,7 @@ let reify_goal interp task =
with Exit -> invert_pat vl (env, fr) (p2, f) t
end
| Pvar _, Tvar _, Tvar _ | Pvar _, Tvar _, Tapp (_, [])
| Pvar _, Tvar _, Tconst _
-> if debug then Format.printf "case vars@.";
(env, fr, t)
| Pvar _, Tapp (ls, _la), _ when ls_equal ls interp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment