Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 6e148bf3 authored by hhakim
Browse files

Avoid useless copies of A and B matrices in GPU2 spgemm and bsrgemm.

parent a03a4d3a
Branches
Tags 3.32.4
No related merge requests found
Pipeline #834109 skipped
...@@ -27,72 +27,72 @@ namespace Faust ...@@ -27,72 +27,72 @@ namespace Faust
{ {
// transpose / adjoint the product to rely on other signature of spgemm (MatSparse B as lhs matrix -- i.e. A) // transpose / adjoint the product to rely on other signature of spgemm (MatSparse B as lhs matrix -- i.e. A)
char nopA, nopB; char nopA, nopB;
MatDense<FPP, GPU2> nA(A);
MatSparse<FPP, GPU2> nB(B);
if(opA == 'N' && opB == 'N') if(opA == 'N' && opB == 'N')
{ {
nopA = 'T'; nopA = 'T';
nopB = 'T'; nopB = 'T';
C.resize(nB.getNbCol(), nA.getNbRow()); C.resize(B.getNbCol(), A.getNbRow());
spgemm(nB, nA, C, alpha, beta, nopB, nopA); spgemm(B, A, C, alpha, beta, nopB, nopA);
C.transpose(); C.transpose();
} }
else if(opA == 'N' && opB == 'T') else if(opA == 'N' && opB == 'T')
{ {
nopA = 'T'; nopA = 'T';
C.resize(nB.getNbRow(), nA.getNbRow()); C.resize(B.getNbRow(), A.getNbRow());
spgemm(nB, nA, C, alpha, beta, opB, nopA); spgemm(B, A, C, alpha, beta, opB, nopA);
C.transpose(); C.transpose();
} }
else if(opA == 'T' && opB == 'N') else if(opA == 'T' && opB == 'N')
{ {
nopB = 'T'; nopB = 'T';
C.resize(nB.getNbCol(), nA.getNbCol()); C.resize(B.getNbCol(), A.getNbCol());
spgemm(nB, nA, C, alpha, beta, nopB, opA); spgemm(B, A, C, alpha, beta, nopB, opA);
C.transpose(); C.transpose();
} }
else if(opA == 'T' && opB == 'T') else if(opA == 'T' && opB == 'T')
{ {
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(B.getNbRow(), A.getNbCol());
spgemm(nB, nA, C, alpha, beta, opB, opA); spgemm(B, A, C, alpha, beta, opB, opA);
C.transpose(); C.transpose();
} }
else if(opA == 'N' && opB == 'H') else if(opA == 'N' && opB == 'H')
{ {
nopA = 'H'; nopA = 'H';
C.resize(nB.getNbRow(), nA.getNbRow()); C.resize(B.getNbRow(), A.getNbRow());
spgemm(nB, nA, C, alpha, beta, opB, nopA); spgemm(B, A, C, alpha, beta, opB, nopA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'H' && opB == 'N') else if(opA == 'H' && opB == 'N')
{ {
nopB = 'H'; nopB = 'H';
C.resize(nB.getNbCol(), nA.getNbCol()); C.resize(B.getNbCol(), A.getNbCol());
spgemm(nB, nA, C, alpha, beta, nopB, opA); spgemm(B, A, C, alpha, beta, nopB, opA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'H' && opB == 'H') else if(opA == 'H' && opB == 'H')
{ {
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(B.getNbRow(), A.getNbCol());
spgemm(nB, nA, C, alpha, beta, opB, opA); spgemm(B, A, C, alpha, beta, opB, opA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'H' && opB == 'T') else if(opA == 'H' && opB == 'T')
{ {
nopA = 'N'; nopA = 'N';
MatSparse<FPP, GPU2> nB(B);
nB.conjugate(); nB.conjugate();
nopB = 'N'; nopB = 'N';
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(nB.getNbRow(), A.getNbCol());
spgemm(nB, nA, C, alpha, beta, nopB, nopA); spgemm(nB, A, C, alpha, beta, nopB, nopA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'T' && opB == 'H') else if(opA == 'T' && opB == 'H')
{ {
MatDense<FPP, GPU2> nA(A);
nA.conjugate(); nA.conjugate();
nopA = 'N'; nopA = 'N';
nopB = 'N'; nopB = 'N';
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(B.getNbRow(), nA.getNbCol());
spgemm(nB, nA, C, alpha, beta, nopB, nopA); spgemm(B, nA, C, alpha, beta, nopB, nopA);
C.adjoint(); C.adjoint();
} }
} }
...@@ -163,72 +163,73 @@ namespace Faust ...@@ -163,72 +163,73 @@ namespace Faust
{ {
// transpose / adjoint the product to rely on other signature of bsrgemm (MatSparse B as lhs matrix -- i.e. A) // transpose / adjoint the product to rely on other signature of bsrgemm (MatSparse B as lhs matrix -- i.e. A)
char nopA, nopB; char nopA, nopB;
MatDense<FPP, GPU2> nA(A);
MatBSR<FPP, GPU2> nB(B);
if(opA == 'N' && opB == 'N') if(opA == 'N' && opB == 'N')
{ {
nopA = 'T'; nopA = 'T';
nopB = 'T'; nopB = 'T';
C.resize(nB.getNbCol(), nA.getNbRow()); C.resize(B.getNbCol(), A.getNbRow());
bsrgemm(nB, nA, C, alpha, beta, nopB, nopA); bsrgemm(B, A, C, alpha, beta, nopB, nopA);
C.transpose(); C.transpose();
} }
else if(opA == 'N' && opB == 'T') else if(opA == 'N' && opB == 'T')
{ {
nopA = 'T'; nopA = 'T';
C.resize(nB.getNbRow(), nA.getNbRow()); C.resize(B.getNbRow(), A.getNbRow());
bsrgemm(nB, nA, C, alpha, beta, opB, nopA); bsrgemm(B, A, C, alpha, beta, opB, nopA);
C.transpose(); C.transpose();
} }
else if(opA == 'T' && opB == 'N') else if(opA == 'T' && opB == 'N')
{ {
nopB = 'T'; nopB = 'T';
C.resize(nB.getNbCol(), nA.getNbCol()); C.resize(B.getNbCol(), A.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, opA); bsrgemm(B, A, C, alpha, beta, nopB, opA);
C.transpose(); C.transpose();
} }
else if(opA == 'T' && opB == 'T') else if(opA == 'T' && opB == 'T')
{ {
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(B.getNbRow(), A.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, opB, opA); bsrgemm(B, A, C, alpha, beta, opB, opA);
C.transpose(); C.transpose();
} }
else if(opA == 'N' && opB == 'H') else if(opA == 'N' && opB == 'H')
{ {
nopA = 'H'; nopA = 'H';
C.resize(nB.getNbRow(), nA.getNbRow()); C.resize(B.getNbRow(), A.getNbRow());
bsrgemm(nB, nA, C, alpha, beta, opB, nopA); bsrgemm(B, A, C, alpha, beta, opB, nopA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'H' && opB == 'N') else if(opA == 'H' && opB == 'N')
{ {
nopB = 'H'; nopB = 'H';
C.resize(nB.getNbCol(), nA.getNbCol()); C.resize(B.getNbCol(), A.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, opA); bsrgemm(B, A, C, alpha, beta, nopB, opA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'H' && opB == 'H') else if(opA == 'H' && opB == 'H')
{ {
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(B.getNbRow(), A.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, opB, opA); bsrgemm(B, A, C, alpha, beta, opB, opA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'H' && opB == 'T') else if(opA == 'H' && opB == 'T')
{ {
nopA = 'N'; nopA = 'N';
MatBSR<FPP, GPU2> nB(B);
nB.conjugate(); nB.conjugate();
nopB = 'N'; nopB = 'N';
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(nB.getNbRow(), A.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, nopA); bsrgemm(nB, A, C, alpha, beta, nopB, nopA);
C.adjoint(); C.adjoint();
} }
else if(opA == 'T' && opB == 'H') else if(opA == 'T' && opB == 'H')
{ {
MatDense<FPP, GPU2> nA(A);
nA.conjugate(); nA.conjugate();
nopA = 'N'; nopA = 'N';
nopB = 'N'; nopB = 'N';
C.resize(nB.getNbRow(), nA.getNbCol()); C.resize(B.getNbRow(), nA.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, nopA); bsrgemm(B, nA, C, alpha, beta, nopB, nopA);
C.adjoint(); C.adjoint();
} }
} }
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment