Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 77cc63c1 authored by hhakim's avatar hhakim
Browse files

Fix multiple errors in gemm_gen function which was modified recently to handle...

Fix multiple errors in gemm_gen function which was modified recently to handle MatBSR A or B arguments.
parent d2ad6ca7
Branches
Tags
No related merge requests found
...@@ -865,6 +865,7 @@ namespace Faust ...@@ -865,6 +865,7 @@ namespace Faust
{ {
template<typename FPP> void gemm_gen(const MatGeneric<FPP, Cpu>& A, const MatGeneric<FPP, Cpu>& B, MatDense<FPP, Cpu>& out, const FPP alpha/*=FPP(1.0)*/, const FPP beta/*=(0.0)*/, const char opA/*='N'*/, const char opB/*='N'*/) template<typename FPP> void gemm_gen(const MatGeneric<FPP, Cpu>& A, const MatGeneric<FPP, Cpu>& B, MatDense<FPP, Cpu>& out, const FPP alpha/*=FPP(1.0)*/, const FPP beta/*=(0.0)*/, const char opA/*='N'*/, const char opB/*='N'*/)
{ {
std::runtime_error type_err("faust_linear_algebra mul function doesn't handle other type of factors than MatDense, MatSparse or MatBSR.");
if(opA != 'N' && opA != 'T' && opA != 'H') if(opA != 'N' && opA != 'T' && opA != 'H')
throw std::runtime_error("opA must be among 'N', 'T', 'H'"); throw std::runtime_error("opA must be among 'N', 'T', 'H'");
if(opB != 'N' && opB != 'T' && opB != 'H') if(opB != 'N' && opB != 'T' && opB != 'H')
...@@ -936,26 +937,32 @@ namespace Faust ...@@ -936,26 +937,32 @@ namespace Faust
if(opB == 'N') if(opB == 'N')
dsm = bsrA->bmat.mul(spB->mat); dsm = bsrA->bmat.mul(spB->mat);
else if (opB == 'T') else if (opB == 'T')
dsm = bsrA->bmat.mul(spB->mat).transpose(); {
Eigen::SparseMatrix<FPP,Eigen::RowMajor> smat_t = const_cast<MatSparse<FPP, Cpu>*>(spB)->mat.transpose().eval();
dsm = bsrA->bmat.mul(smat_t);
}
else if (opB == 'H') else if (opB == 'H')
dsm = bsrA->bmat.mul(spB->mat).adjoint(); {
Eigen::SparseMatrix<FPP,Eigen::RowMajor> smat_a = const_cast<MatSparse<FPP, Cpu>*>(spB)->mat.adjoint().eval();
dsm = bsrA->bmat.mul(smat_a);
}
} }
else if(dsB = dynamic_cast<const MatDense<FPP, Cpu>*>(&B)) else if(dsB = dynamic_cast<const MatDense<FPP, Cpu>*>(&B))
{ {
if(opB == 'N') if(opB == 'N')
dsm = bsrA->bmat.mul(dsB->mat); dsm = bsrA->bmat.mul(dsB->mat);
else if (opB == 'T') else if (opB == 'T')
dsm = bsrA->bmat.mul(dsB->mat).transpose(); dsm = bsrA->bmat.mul(const_cast<MatDense<FPP, Cpu>*>(dsB)->mat.transpose());
else if (opB == 'H') else if (opB == 'H')
dsm = bsrA->bmat.mul(dsB->mat).adjoint(); dsm = bsrA->bmat.mul(const_cast<MatDense<FPP, Cpu>*>(dsB)->mat.transpose());
} }
else if(bsrB = dynamic_cast<const MatBSR<FPP, Cpu>*>(&B)) else if(bsrB = dynamic_cast<const MatBSR<FPP, Cpu>*>(&B))
if(opB == 'N') if(opB == 'N')
dsm = alpha*bsrA->bmat.mul(bsrB->bmat); dsm = alpha*bsrA->bmat.mul(bsrB->bmat);
else if (opB == 'T') else if (opB == 'T')
dsm = alpha*bsrA->bmat.mul(bsrB->bmat).transpose(); dsm = alpha*bsrA->bmat.mul(const_cast<MatBSR<FPP, Cpu>*>(bsrB)->bmat.transpose());
else // opB == H else // opB == H
dsm = alpha*bsrA->bmat.mul(bsrB->bmat).adjoint(); dsm = alpha*bsrA->bmat.mul(const_cast<MatBSR<FPP, Cpu>*>(bsrB)->bmat.adjoint());
if(beta == FPP(0)) if(beta == FPP(0))
{ {
...@@ -971,19 +978,29 @@ namespace Faust ...@@ -971,19 +978,29 @@ namespace Faust
{ {
MatBSR<FPP, Cpu> transpA(*bsrA); MatBSR<FPP, Cpu> transpA(*bsrA);
transpA.transpose(); transpA.transpose();
return gemm_gen(A, B, out, alpha, beta, 'N', opB); return gemm_gen(transpA, B, out, alpha, beta, 'N', opB);
} }
else if(opA == 'H') else if(opA == 'H')
{ {
MatBSR<FPP, Cpu> adjA(*bsrA); MatBSR<FPP, Cpu> adjA(*bsrA);
adjA.adjoint(); adjA.adjoint();
return gemm_gen(A, B, out, alpha, beta, 'N', opB); return gemm_gen(adjA, B, out, alpha, beta, 'N', opB);
} }
} }
else if(dynamic_cast<const MatBSR<FPP, Cpu>*>(&A)) else if(dynamic_cast<const MatBSR<FPP, Cpu>*>(&B))
{ {
//TODO: refactor //TODO: refactor
const MatBSR<FPP, Cpu>* bsrA = dynamic_cast<const MatBSR<FPP, Cpu>*>(&A); const MatBSR<FPP, Cpu>* bsrA = dynamic_cast<const MatBSR<FPP, Cpu>*>(&A);
const MatSparse<FPP, Cpu>* spA = nullptr;
const MatDense<FPP, Cpu>* dsA = nullptr;
auto bsrB = dynamic_cast<const MatBSR<FPP, Cpu>*>(&B);
if(!bsrA)
if(! (spA = dynamic_cast<const MatSparse<FPP, Cpu>*>(&A)))
{
dsA = dynamic_cast<const MatDense<FPP, Cpu>*>(&A);
if(!dsA)
throw type_err;
}
if(opA == 'N' && opB == 'N') if(opA == 'N' && opB == 'N')
{ {
gemm_gen(B, A, out, alpha, beta, 'T', 'T'); gemm_gen(B, A, out, alpha, beta, 'T', 'T');
...@@ -1011,9 +1028,9 @@ namespace Faust ...@@ -1011,9 +1028,9 @@ namespace Faust
} }
else if(opA == 'T' && opB == 'H') else if(opA == 'T' && opB == 'H')
{ {
MatBSR<FPP, Cpu> A_conj(*bsrA); MatBSR<FPP, Cpu> B_conj(*bsrB);
A_conj.conjugate(); B_conj.conjugate();
gemm_gen(B, A_conj, out, alpha, beta, 'N', 'N'); gemm_gen(B_conj, A, out, alpha, beta, 'N', 'N');
out.adjoint(); out.adjoint();
} }
if(opA == 'H' && opB == 'N') if(opA == 'H' && opB == 'N')
...@@ -1023,9 +1040,24 @@ namespace Faust ...@@ -1023,9 +1040,24 @@ namespace Faust
} }
else if(opA == 'H' && opB == 'T') else if(opA == 'H' && opB == 'T')
{ {
MatBSR<FPP, Cpu> A_conj(*bsrA); if(bsrA)
A_conj.conjugate(); {
gemm_gen(B, A_conj, out, alpha, beta, 'N', 'N'); MatBSR<FPP, Cpu> A_conj(*bsrA);
A_conj.conjugate();
gemm_gen(B, A_conj, out, alpha, beta, 'N', 'N');
}
else if(spA)
{
MatSparse<FPP, Cpu> A_conj(*spA);
A_conj.conjugate();
gemm_gen(B, A_conj, out, alpha, beta, 'N', 'N');
}
else //dsA
{
MatDense<FPP, Cpu> A_conj(*dsA);
A_conj.conjugate();
gemm_gen(B, A_conj, out, alpha, beta, 'N', 'N');
}
out.transpose(); out.transpose();
} }
else if(opA == 'H' && opB == 'H') else if(opA == 'H' && opB == 'H')
...@@ -1036,7 +1068,7 @@ namespace Faust ...@@ -1036,7 +1068,7 @@ namespace Faust
} }
else else
{ {
throw std::runtime_error("faust_linear_algebra mul function doesn't handle other type of factors than MatDense, MatSparse or MatBSR."); throw type_err;
} }
out.update_dims(); out.update_dims();
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment