Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 4b1447e9 authored by hhakim's avatar hhakim
Browse files

Fix bugs in MatButterfly::faust_gemm and implement a conjugate argument to multiply prototypes.

parent 699b17d7
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,8 @@ types::ndarray<T, types::pshape<long>> arrayFromBuf1D(T* fPtr, long size)
#include <Eigen/Core>
#include <memory> // shared_ptr
#define diag_conj(D) DiagMat(D.diagonal().conjugate()) // D is a DiagMat
namespace Faust
{
template<typename FPP, FDevice DEVICE> class MatButterfly;
......@@ -58,8 +60,8 @@ namespace Faust
faust_unsigned_int getNbRow() const { return D1.rows();};
faust_unsigned_int getNbCol() const { return D1.cols();};
void multiply(const FPP* x, FPP* y, size_t size, bool transpose = false);
void multiply(const FPP* A, int A_ncols, FPP* C, size_t size, bool transpose = false);
void multiply(const FPP* x, FPP* y, size_t size, bool transpose = false, bool conjugate = false);
void multiply(const FPP* A, int A_ncols, FPP* C, size_t size, bool transpose = false, bool conjugate = false);
const DiagMat& getD1() {return D1;}; //TODO/ move to .hpp
const DiagMat& getD2() {return D2;};
......
......@@ -140,7 +140,7 @@ namespace Faust
void MatButterfly<FPP, Cpu>::multiply(MatDense<FPP,Cpu> & M, char opThis) const
{
MatDense<FPP, Cpu> Y(this->getNbRow(), M.getNbCol());
const_cast<MatButterfly<FPP, Cpu>*>(this)->multiply(M.getData(), M.getNbCol(), Y.getData(), Y.getNbRow(), 'T' == opThis);
const_cast<MatButterfly<FPP, Cpu>*>(this)->multiply(M.getData(), M.getNbCol(), Y.getData(), Y.getNbRow(), 'N' != opThis, 'H' == opThis || 'C' == opThis);
M = Y;
}
......@@ -160,7 +160,7 @@ namespace Faust
template<typename FPP>
void MatButterfly<FPP, Cpu>::multiply(const FPP* x, FPP* y, size_t size, bool transpose/* = false*/)
void MatButterfly<FPP, Cpu>::multiply(const FPP* x, FPP* y, size_t size, bool transpose/* = false*/, bool conjugate/* = false*/)
{
DiagMat &D2 = this->D2;
const FPP *d1_ptr, *d2_ptr;
......@@ -187,20 +187,29 @@ namespace Faust
#ifdef BMAT_MULTIPLY_VEC_OMP_LOOP
#pragma omp parallel for
for(int i=0;i < size; i++)
y[i] = d1_ptr[i] * x[i] + d2_ptr[i] * x[subdiag_ids[i]];
if(conjugate)
y[i] = std::conj(d1_ptr[i]) * x[i] + std::conj(d2_ptr[i]) * x[subdiag_ids[i]];
else
y[i] = d1_ptr[i] * x[i] + d2_ptr[i] * x[subdiag_ids[i]];
#else
// this is slower
VecMap x_vec(const_cast<FPP*>(x), size); // const_cast is harmless
VecMap y_vec(y, size); // const_cast is harmless
if(do_transp)
y_vec = D1 * x_vec + D2T * x_vec(subdiag_ids);
if(conjugate)
y_vec = diag_conj(D1) * x_vec + diag_conj(D2T) * x_vec(subdiag_ids);
else
y_vec = D1 * x_vec + D2T * x_vec(subdiag_ids);
else
y_vec = D1 * x_vec + D2 * x_vec(subdiag_ids);
if(conjugate)
y_vec = diag_conj(D1) * x_vec + diag_conj(D2) * x_vec(subdiag_ids);
else
y_vec = D1 * x_vec + D2 * x_vec(subdiag_ids);
#endif
}
template<typename FPP>
void MatButterfly<FPP, Cpu>::multiply(const FPP* X, int X_ncols, FPP* Y, size_t Y_nrows, bool transpose/* = false*/)
void MatButterfly<FPP, Cpu>::multiply(const FPP* X, int X_ncols, FPP* Y, size_t Y_nrows, bool transpose/* = false*/, bool conjugate /* = false*/)
{
using MatMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, Eigen::Dynamic>>;
MatMap X_mat(const_cast<FPP*>(X) /* harmless, no modification*/, Y_nrows, X_ncols);
......@@ -221,12 +230,22 @@ namespace Faust
// this is slower
#pragma omp parallel for
for(int i=0;i < Y_nrows; i++)
Y_mat.row(i) = d1_ptr[i] * X_mat.row(i) + d2_ptr[i] * X_mat.row(subdiag_ids[i]);
if(conjugate)
Y_mat.row(i) = std::conj(d1_ptr[i]) * X_mat.row(i) + std::conj(d2_ptr[i]) * X_mat.row(subdiag_ids[i]);
else
Y_mat.row(i) = d1_ptr[i] * X_mat.row(i) + d2_ptr[i] * X_mat.row(subdiag_ids[i]);
#else
// TODO: refactor the exp with a macro diag_prod
if(do_transp)
Y_mat = D1 * X_mat + D2T * X_mat(subdiag_ids, Eigen::placeholders::all);
if(conjugate)
Y_mat = diag_conj(D1) * X_mat + diag_conj(D2T) * X_mat(subdiag_ids, Eigen::placeholders::all);
else
Y_mat = D1 * X_mat + D2T * X_mat(subdiag_ids, Eigen::placeholders::all);
else
Y_mat = D1 * X_mat + D2 * X_mat(subdiag_ids, Eigen::placeholders::all);
if(conjugate)
Y_mat = diag_conj(D1) * X_mat + diag_conj(D2) * X_mat(subdiag_ids, Eigen::placeholders::all);
else
Y_mat = D1 * X_mat + D2 * X_mat(subdiag_ids, Eigen::placeholders::all);
#endif
}
......@@ -240,41 +259,44 @@ namespace Faust
if(typeB == 'N')
{
C = B;
multiply(C, 'N');
multiply(C, typeA);
}
else if(typeB == 'T')
{
auto C = B;
C.transpose();
multiply(C, 'N');
multiply(C, typeA);
}
else if(typeB == 'H')
{
auto C = B;
C.adjoint();
multiply(C, 'N');
multiply(C, typeA);
}
else
throw op_except;
C *= alpha;
}
else // beta != 0
{
C *= beta;
MatDense<FPP, Cpu> Bc(B); //copy
if(alpha != FPP(0))
Bc *= alpha;
if(typeB == 'N')
{
multiply(Bc, 'N');
multiply(Bc, typeA);
}
else if(typeB == 'T')
{
Bc.transpose();
multiply(Bc, 'N');
multiply(Bc, typeA);
}
else if(typeB == 'H')
{
Bc.transpose();
multiply(Bc, 'N');
Bc.adjoint();
multiply(Bc, typeA);
}
else
throw op_except;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment