Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 6e193ade authored by hhakim's avatar hhakim
Browse files

Implement bsrgemm in faust_linear_algebra GPU2 module.

parent 528ac10f
Branches
Tags
No related merge requests found
......@@ -28,7 +28,15 @@ namespace Faust
// Computes alpha*opA(A)*opB(B) + beta*C into C, with A dense and B sparse.
// opA/opB select the operation applied to A/B: 'N' (none), 'T' (transpose) or 'H' (transconjugate).
// \param impl_meth: in any case this function relies on the previous spgemm prototype; if impl_meth is 1 then transpose/transconjugate identities are used to avoid converting A and B to another type of matrix, otherwise (impl_meth is any other value) A is converted to a MatSparse and B to a MatDense.
template<typename FPP>
void spgemm(const MatDense<FPP,GPU2> & A,const MatSparse<FPP,GPU2> & B, MatDense<FPP,GPU2> & C,const FPP & alpha, const FPP & beta, char opA, char opB, int impl_meth = 1);
//
// Computes alpha*opA(A)*opB(B) + beta*C into C, with A a BSR matrix and B dense.
// opA/opB select the operation applied to A/B: 'N' (none), 'T' (transpose) or 'H' (transconjugate).
template<typename FPP>
void bsrgemm(const MatBSR<FPP,GPU2> & A,const MatDense<FPP,GPU2> & B, MatDense<FPP,GPU2> & C,const FPP & alpha, const FPP & beta, char opA, char opB);
// Computes alpha*opA(A)*opB(B) + beta*C into C, with A dense and B a BSR matrix.
// opA/opB select the operation applied to A/B: 'N' (none), 'T' (transpose) or 'H' (transconjugate).
// \param impl_meth: in any case this function relies on the previous bsrgemm prototype; if impl_meth is 1 then transpose/transconjugate identities are used to avoid converting A and B to another type of matrix, otherwise (impl_meth is any other value) the operands are converted to dense matrices before computing the product.
template<typename FPP>
void bsrgemm(const MatDense<FPP,GPU2> & A,const MatBSR<FPP,GPU2> & B, MatDense<FPP,GPU2> & C,const FPP & alpha, const FPP & beta, char opA, char opB, int impl_meth = 1);
}
#include "faust_linear_algebra_gpu.hpp"
#endif
......@@ -117,7 +117,7 @@ namespace Faust
if(! spB)
dsB = dynamic_cast<const Faust::MatDense<FPP,GPU2>*>(&B);
if(spA && spB)
throw std::runtime_error("gemm on two MatSparse is not supported.");
spgemm(*spA, MatDense<FPP, GPU2>(*spB), C, alpha, beta, typeA, typeB);
else if(spA)
spgemm(*spA, *dsB, C, alpha, beta, typeA, typeB);
else if(spB)
......@@ -125,4 +125,95 @@ namespace Faust
else
gemm(*dsA, *dsB, C, alpha, beta, typeA, typeB);
}
// Computes alpha*opA(A)*opB(B) + beta*C into C, with A a BSR matrix and B dense.
// Thin wrapper: delegates directly to the MatBSR<FPP,GPU2> class implementation.
// opA/opB: 'N' (no-op), 'T' (transpose) or 'H' (transconjugate).
template<typename FPP>
void bsrgemm(const MatBSR<FPP,GPU2> & A,const MatDense<FPP,GPU2> & B, MatDense<FPP,GPU2> & C,const FPP & alpha, const FPP & beta, char opA, char opB)
{
MatBSR<FPP, GPU2>::bsrgemm(A, B, C, alpha, beta, opA, opB);
}
template<typename FPP>
void bsrgemm(const MatDense<FPP,GPU2> & A, const MatBSR<FPP,GPU2> & B, MatDense<FPP,GPU2> & C, const FPP & alpha, const FPP & beta, char opA, char opB, int impl_meth/* = 1*/)
{
//TODO: benchmark the two methods (impl_meth == 1 and 2)
if (impl_meth == 1)
{
// transpose / adjoint the product to rely on other signature of bsrgemm (MatSparse B as lhs matrix -- i.e. A)
char nopA, nopB;
MatDense<FPP, GPU2> nA(A);
MatDense<FPP, GPU2> nB(B);
if(opA == 'N' && opB == 'N')
{
nopA = 'T';
nopB = 'T';
C.resize(nB.getNbCol(), nA.getNbRow());
bsrgemm(nB, nA, C, alpha, beta, nopB, nopA);
C.transpose();
}
else if(opA == 'N' && opB == 'T')
{
nopA = 'T';
C.resize(nB.getNbRow(), nA.getNbRow());
bsrgemm(nB, nA, C, alpha, beta, opB, nopA);
C.transpose();
}
else if(opA == 'T' && opB == 'N')
{
nopB = 'T';
C.resize(nB.getNbCol(), nA.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, opA);
C.transpose();
}
else if(opA == 'T' && opB == 'T')
{
C.resize(nB.getNbRow(), nA.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, opB, opA);
C.transpose();
}
else if(opA == 'N' && opB == 'H')
{
nopA = 'H';
C.resize(nB.getNbRow(), nA.getNbRow());
bsrgemm(nB, nA, C, alpha, beta, opB, nopA);
C.adjoint();
}
else if(opA == 'H' && opB == 'N')
{
nopB = 'H';
C.resize(nB.getNbCol(), nA.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, opA);
C.adjoint();
}
else if(opA == 'H' && opB == 'H')
{
C.resize(nB.getNbRow(), nA.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, opB, opA);
C.adjoint();
}
else if(opA == 'H' && opB == 'T')
{
nopA = 'N';
nB.conjugate();
nopB = 'N';
C.resize(nB.getNbRow(), nA.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, nopA);
C.adjoint();
}
else if(opA == 'T' && opB == 'H')
{
nA.conjugate();
nopA = 'N';
nopB = 'N';
C.resize(nB.getNbRow(), nA.getNbCol());
bsrgemm(nB, nA, C, alpha, beta, nopB, nopA);
C.adjoint();
}
}
else {
bsrgemm(MatDense<FPP, GPU2>(A), MatDense<FPP, GPU2>(B), C, alpha, beta, opA, opB);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment