Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 423c4532 authored by hhakim's avatar hhakim
Browse files

Update to gpu_mod@a353f788 and add gemm function for MatDense<FPP,GPU2>.

parent e6ce2cf0
No related branches found
No related tags found
No related merge requests found
......@@ -10,9 +10,14 @@ namespace Faust
template<typename FPP,FDevice DEVICE>
class MatDense;
template <typename FPP>
void gemm(const MatDense<FPP, GPU2> &A, const MatDense<FPP, GPU2> &B, MatDense<FPP, GPU2> &C, const FPP& alpha, const FPP& beta, const char opA, const char opB);
template<typename FPP>
class MatDense<FPP, GPU2> : MatDense<FPP, Cpu>
{
friend void gemm<>(const MatDense<FPP, GPU2> &A, const MatDense<FPP, GPU2> &B, MatDense<FPP, GPU2> &C, const FPP& alpha, const FPP& beta, const char opA, const char opB);
public:
MatDense(const faust_unsigned_int nbRow,
const faust_unsigned_int nbCol,
......@@ -83,6 +88,7 @@ namespace Faust
gm_DenseMat_t gpu_mat;
};
template <typename FPP>
void* Faust::MatDense<FPP,GPU2>::dsm_funcs = nullptr;
......
......@@ -415,4 +415,22 @@ namespace Faust
this->dim1 = A.dim1;
this->dim2 = A.dim2;
}
template <>
void gemm(const MatDense<@FAUST_SCALAR_FOR_GM@, GPU2> &A, const MatDense<@FAUST_SCALAR_FOR_GM@, GPU2> &B, MatDense<@FAUST_SCALAR_FOR_GM@, GPU2> &C,
const @FAUST_SCALAR_FOR_GM@& alpha, const @FAUST_SCALAR_FOR_GM@& beta, const char opA, const char opB)
{
gm_Op gop_A = OP_NOTRANSP, gop_B = OP_NOTRANSP;
if(opA == 'T')
gop_A = OP_TRANSP;
else if(opA == 'H')
gop_A = OP_CONJTRANSP;
if(opB == 'T')
gop_B = OP_TRANSP;
else if(opB == 'H')
gop_B = OP_CONJTRANSP;
auto dsm_funcs = ((gm_DenseMatFunc_@GM_SCALAR@*) MatDense<@FAUST_SCALAR_FOR_GM@, GPU2>::dsm_funcs);
dsm_funcs->gemm(A.gpu_mat, B.gpu_mat, C.gpu_mat, reinterpret_cast<const @GM_REINTERPRET_CAST_SCALAR@*>(&alpha), reinterpret_cast<const @GM_REINTERPRET_CAST_SCALAR@*>(&beta), gop_A, gop_B);
}
};
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment