Mentions légales du service

Skip to content
Snippets Groups Projects
Commit de45fab9 authored by hhakim's avatar hhakim
Browse files

Add abstract "MatGeneric<FPP, GPU2>::multiply(MatDense<FPP, GPU2>&, const...

Add abstract "MatGeneric<FPP, GPU2>::multiply(MatDense<FPP, GPU2>&, const char) const" and a fix of its overload in MatDense<FPP, GPU2>.
parent 60867b66
Branches
Tags
No related merge requests found
......@@ -94,7 +94,7 @@ namespace Faust
void Display() const;
Real<FPP> norm() const;
void multiply(MatDense<FPP, GPU2> &other, const char op_this);
void multiply(MatDense<FPP, GPU2> &other, const char op_this) const;
void multiply(MatSparse<FPP, GPU2> &other, const char op_this);
MatSparse<FPP, GPU2> toMatSparse() const;
~MatButterfly();
......
......@@ -24,7 +24,7 @@ namespace Faust
template<typename FPP>
void MatButterfly<FPP, GPU2>::multiply(MatDense<FPP, GPU2> &other, const char op_this)
void MatButterfly<FPP, GPU2>::multiply(MatDense<FPP, GPU2> &other, const char op_this) const
{
if(op_this != 'N' && op_this != 'T')
throw std::runtime_error("MatButtermfly::multiply only handle 'N' and 'T' for op_this");
......
......@@ -67,17 +67,6 @@ namespace Faust
const void* stream/*=nullptr*/) : MatDense<@FAUST_SCALAR_FOR_GM@,GPU2>(mat.getNbRow(), mat.getNbCol(), mat.getData(), /*no_alloc*/ mat.getData() == nullptr, dev_id, stream){}
template<>
void Faust::MatDense<@FAUST_SCALAR_FOR_GM@,GPU2>::multiply(const MatDense<@FAUST_SCALAR_FOR_GM@, GPU2> &other, const char op_this)
{
// other = this * other
gm_Op gop_this;
char2gm_Op(op_this, gop_this);
auto dsm_funcs = GPUModHandler::get_singleton()->dsm_funcs(@FAUST_SCALAR_FOR_GM@(0));
dsm_funcs->mul_gpu_dsm_ext(this->gpu_mat, other.gpu_mat, other.gpu_mat, gop_this, OP_NOTRANSP);
}
template<>
void Faust::MatDense<@FAUST_SCALAR_FOR_GM@,GPU2>::multiply(MatDense<@FAUST_SCALAR_FOR_GM@, Cpu> &other, const char op_this)
{
......@@ -632,6 +621,28 @@ namespace Faust
return *this;
}
template<>
void Faust::MatDense<@FAUST_SCALAR_FOR_GM@,GPU2>::multiply(MatDense<@FAUST_SCALAR_FOR_GM@, GPU2> &other, const char op_this) const
{
// other = this * other
gm_Op gop_this;
char2gm_Op(op_this, gop_this);
auto dsm_funcs = GPUModHandler::get_singleton()->dsm_funcs(@FAUST_SCALAR_FOR_GM@(0));
faust_unsigned_int m, n; // prod dims
m = op_this == 'N'?this->getNbRow():this->getNbCol();
n = other.getNbCol();
// is other large enough to receive this * other ?
if(other.getNbRow() < m)
{
MatDense<@FAUST_SCALAR_FOR_GM@, GPU2> out(m, n);
dsm_funcs->mul_gpu_dsm_ext(this->gpu_mat, other.gpu_mat, out.gpu_mat, gop_this, OP_NOTRANSP);
other = std::move(out);
}
else
dsm_funcs->mul_gpu_dsm_ext(this->gpu_mat, other.gpu_mat, other.gpu_mat, gop_this, OP_NOTRANSP);
}
template<>
MatDense<@FAUST_SCALAR_FOR_GM@,GPU2>::MatDense(MatDense<@FAUST_SCALAR_FOR_GM@,GPU2>&& mat)
{
......
......@@ -103,7 +103,7 @@ namespace Faust
void eltwise_mul(const Vect<FPP, GPU2> &vec, const int *ids=nullptr);
// TODO: other shouldn't be const if it is the output
// other = (*this) * other
void multiply(const MatDense<FPP, GPU2> &other, const char op_this='N');
void multiply(MatDense<FPP, GPU2> &other, const char op_this='N') const;
// other = (*this) * other
void multiply(MatDense<FPP, Cpu> &other, const char op_this='N');
// void multiply(MatSparse<FPP, Cpu> &other, MatDense<FPP, GPU2>& output, const char op_this='N');
......
......@@ -46,6 +46,8 @@ namespace Faust
virtual void Display() const=0;
virtual Real<FPP> norm() const=0;
virtual void multiply(MatDense<FPP,GPU2> &A, const char opThis) const =0;
MatGeneric();
virtual ~MatGeneric();
......
......@@ -90,7 +90,7 @@ namespace Faust
void Display() const;
Real<FPP> norm() const;
void multiply(MatDense<FPP, GPU2> &other, const char op_this);
void multiply(MatDense<FPP, GPU2> &other, const char op_this) const;
void multiply(MatSparse<FPP, GPU2> &other, const char op_this);
MatSparse<FPP, GPU2> toMatSparse() const;
......
......@@ -21,7 +21,7 @@ namespace Faust
template<typename FPP>
void MatPerm<FPP, GPU2>::multiply(MatDense<FPP, GPU2> &other, const char op_this)
void MatPerm<FPP, GPU2>::multiply(MatDense<FPP, GPU2> &other, const char op_this) const
{
if(op_this != 'N' && op_this != 'T')
throw std::runtime_error("MatButtermfly::multiply only handle 'N' and 'T' for op_this");
......
......@@ -474,7 +474,7 @@ namespace Faust
}
template<>
void MatSparse<FSFG, GPU2>::multiply(MatDense<FSFG,GPU2>& mat, char opThis/*='N'*/) const
void MatSparse<FSFG, GPU2>::multiply(MatDense<FSFG,GPU2>& mat, const char opThis/*='N'*/) const
{
gm_Op gop_this;
char2gm_Op(opThis, gop_this);
......
......@@ -100,7 +100,7 @@ namespace Faust
std::string to_string(const bool transpose=false, const bool displaying_small_mat_elts=false) const;
MatType getType() const;
void multiply(Vect<FPP,GPU2>& vec, char opThis='N') const;
void multiply(MatDense<FPP,GPU2>& vec, char opThis='N') const;
void multiply(MatDense<FPP,GPU2>& mat, const char opThis='N') const;
static void spgemm(const MatSparse<FPP,GPU2> & A, const MatDense<FPP,GPU2> & B, MatDense<FPP,GPU2> & C, const FPP & alpha, const FPP & beta, const char opA, const char opB);
MatBSR<FPP, GPU2> to_bsr(int bsize) const;
~MatSparse();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment