Mentions légales du service

Skip to content
Snippets Groups Projects
Commit f7f3aa08 authored by hhakim's avatar hhakim
Browse files

Integrate BSR matrix in GPU2 gemm_gen using bsrgemm.

parent 486e8b27
Branches
Tags
No related merge requests found
...@@ -109,21 +109,44 @@ namespace Faust ...@@ -109,21 +109,44 @@ namespace Faust
const MatSparse<FPP, GPU2>* spB; const MatSparse<FPP, GPU2>* spB;
const MatDense<FPP, GPU2>* dsA; const MatDense<FPP, GPU2>* dsA;
const MatDense<FPP, GPU2>* dsB; const MatDense<FPP, GPU2>* dsB;
const MatBSR<FPP, GPU2>* bsrA;
const MatBSR<FPP, GPU2>* bsrB;
// downcast and call the proper function // downcast and call the proper function
spA = dynamic_cast<const Faust::MatSparse<FPP,GPU2>*>(&A); spA = dynamic_cast<const Faust::MatSparse<FPP,GPU2>*>(&A);
if(! spA) if(! spA)
{
dsA = dynamic_cast<const Faust::MatDense<FPP,GPU2>*>(&A); dsA = dynamic_cast<const Faust::MatDense<FPP,GPU2>*>(&A);
if (! dsA)
bsrA = dynamic_cast<const Faust::MatBSR<FPP, GPU2>*>(&A);
}
spB = dynamic_cast<const Faust::MatSparse<FPP,GPU2>*>(&B); spB = dynamic_cast<const Faust::MatSparse<FPP,GPU2>*>(&B);
if(! spB) if(! spB)
{
dsB = dynamic_cast<const Faust::MatDense<FPP,GPU2>*>(&B); dsB = dynamic_cast<const Faust::MatDense<FPP,GPU2>*>(&B);
if (! dsB)
bsrB = dynamic_cast<const Faust::MatBSR<FPP, GPU2>*>(&B);
}
if(spA && spB) if(spA && spB)
spgemm(*spA, MatDense<FPP, GPU2>(*spB), C, alpha, beta, typeA, typeB); spgemm(*spA, MatDense<FPP, GPU2>(*spB), C, alpha, beta, typeA, typeB);
else if(spA) else if(spA && dsB)
spgemm(*spA, *dsB, C, alpha, beta, typeA, typeB); spgemm(*spA, *dsB, C, alpha, beta, typeA, typeB);
else if(spB) else if(spB && dsA)
spgemm(*dsA, *spB, C, alpha, beta, typeA, typeB); spgemm(*dsA, *spB, C, alpha, beta, typeA, typeB);
else else if (dsA && dsB)
gemm(*dsA, *dsB, C, alpha, beta, typeA, typeB); gemm(*dsA, *dsB, C, alpha, beta, typeA, typeB);
else if (bsrA && dsB)
bsrgemm(*bsrA, *dsB, C, alpha, beta, typeA, typeB);
else if (bsrA && spB)
bsrgemm(*bsrA, MatDense<FPP, GPU2>(*spB), C, alpha, beta, typeA, typeB);
else if (bsrA && bsrB)
bsrgemm(*bsrA, bsrB->to_dense(), C, alpha, beta, typeA, typeB); // TODO: consider also converting bsrB to MatDense, depending on the weight
else if (bsrB && dsA)
bsrgemm(*dsA, *bsrB, C, alpha, beta, typeA, typeB);
else if (bsrB && spA)
bsrgemm(MatDense<FPP, GPU2>(*spA), *bsrB, C, alpha, beta, typeA, typeB); // TODO: consider also converting bsrB to MatDense, depending on the weight // to test
else
throw std::runtime_error("Unsupported matrix type in faust_linear_algebra_gpu gemm_gen");
} }
template<typename FPP> template<typename FPP>
...@@ -141,7 +164,7 @@ namespace Faust ...@@ -141,7 +164,7 @@ namespace Faust
// transpose / adjoint the product to rely on other signature of bsrgemm (MatSparse B as lhs matrix -- i.e. A) // transpose / adjoint the product to rely on other signature of bsrgemm (MatSparse B as lhs matrix -- i.e. A)
char nopA, nopB; char nopA, nopB;
MatDense<FPP, GPU2> nA(A); MatDense<FPP, GPU2> nA(A);
MatDense<FPP, GPU2> nB(B); MatBSR<FPP, GPU2> nB(B);
if(opA == 'N' && opB == 'N') if(opA == 'N' && opB == 'N')
{ {
nopA = 'T'; nopA = 'T';
...@@ -210,7 +233,7 @@ namespace Faust ...@@ -210,7 +233,7 @@ namespace Faust
} }
} }
else { else {
bsrgemm(MatDense<FPP, GPU2>(A), MatDense<FPP, GPU2>(B), C, alpha, beta, opA, opB); spgemm(A, B.to_sparse(), C, alpha, beta, opA, opB);
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment