Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 23edad79 authored by hhakim's avatar hhakim
Browse files

Fix a bug on transpose case of MatButterfly multiplication and impl. conjugate...

Fix a bug on transpose case of MatButterfly multiplication and impl. conjugate and adjoint member functions.
parent 420c548b
No related branches found
No related tags found
No related merge requests found
...@@ -93,9 +93,9 @@ namespace Faust ...@@ -93,9 +93,9 @@ namespace Faust
void MatButterfly<FPP, Cpu>::init_transpose() void MatButterfly<FPP, Cpu>::init_transpose()
{ {
//TODO: simplify in case of symmetric matrix (it happens for the FFT) //TODO: simplify in case of symmetric matrix (it happens for the FFT)
auto size = D2.rows();
if(D2T.size() == 0) if(D2T.size() == 0)
{ {
auto size = D2.rows();
FPP *d2_ptr, *d2t_ptr; FPP *d2_ptr, *d2t_ptr;
d2_ptr = D2.diagonal().data(); d2_ptr = D2.diagonal().data();
D2T.resize(size); D2T.resize(size);
...@@ -139,7 +139,6 @@ namespace Faust ...@@ -139,7 +139,6 @@ namespace Faust
template<typename FPP> template<typename FPP>
void MatButterfly<FPP, Cpu>::multiply(MatDense<FPP,Cpu> & M, char opThis) const void MatButterfly<FPP, Cpu>::multiply(MatDense<FPP,Cpu> & M, char opThis) const
{ {
MatDense<FPP, Cpu> Y(this->getNbRow(), M.getNbCol()); MatDense<FPP, Cpu> Y(this->getNbRow(), M.getNbCol());
const_cast<MatButterfly<FPP, Cpu>*>(this)->multiply(M.getData(), M.getNbCol(), Y.getData(), Y.getNbRow(), 'T' == opThis); const_cast<MatButterfly<FPP, Cpu>*>(this)->multiply(M.getData(), M.getNbCol(), Y.getData(), Y.getNbRow(), 'T' == opThis);
M = Y; M = Y;
...@@ -166,11 +165,11 @@ namespace Faust ...@@ -166,11 +165,11 @@ namespace Faust
DiagMat &D2 = this->D2; DiagMat &D2 = this->D2;
const FPP *d1_ptr, *d2_ptr; const FPP *d1_ptr, *d2_ptr;
d1_ptr = D1.diagonal().data(); d1_ptr = D1.diagonal().data();
if(transpose ^ is_transp) auto do_transp = transpose ^ is_transp;
if(do_transp)
{ {
init_transpose(); // no cost if already initialized init_transpose(); // no cost if already initialized
d2_ptr = D2T.diagonal().data(); d2_ptr = D2T.diagonal().data();
D2 = D2T;
} }
else else
d2_ptr = D2.diagonal().data(); d2_ptr = D2.diagonal().data();
...@@ -193,7 +192,10 @@ namespace Faust ...@@ -193,7 +192,10 @@ namespace Faust
// this is slower // this is slower
VecMap x_vec(const_cast<FPP*>(x), size); // const_cast is harmless VecMap x_vec(const_cast<FPP*>(x), size); // const_cast is harmless
VecMap y_vec(y, size); // const_cast is harmless VecMap y_vec(y, size); // const_cast is harmless
y_vec = D1 * x_vec + D2 * x_vec(subdiag_ids); if(do_transp)
y_vec = D1 * x_vec + D2T * x_vec(subdiag_ids);
else
y_vec = D1 * x_vec + D2 * x_vec(subdiag_ids);
#endif #endif
} }
...@@ -204,14 +206,13 @@ namespace Faust ...@@ -204,14 +206,13 @@ namespace Faust
MatMap X_mat(const_cast<FPP*>(X) /* harmless, no modification*/, Y_nrows, X_ncols); MatMap X_mat(const_cast<FPP*>(X) /* harmless, no modification*/, Y_nrows, X_ncols);
MatMap Y_mat(Y, Y_nrows, X_ncols); MatMap Y_mat(Y, Y_nrows, X_ncols);
// --------- TODO refactor this block with overload of multiply // --------- TODO refactor this block with overload of multiply
DiagMat &D2 = this->D2;
const FPP *d1_ptr, *d2_ptr; const FPP *d1_ptr, *d2_ptr;
d1_ptr = D1.diagonal().data(); d1_ptr = D1.diagonal().data();
if(transpose ^ is_transp) auto do_transp = transpose ^ is_transp;
if(do_transp)
{ {
init_transpose(); // no cost if already initialized init_transpose(); // no cost if already initialized
d2_ptr = D2T.diagonal().data(); d2_ptr = D2T.diagonal().data();
D2 = D2T;
} }
else else
d2_ptr = D2.diagonal().data(); d2_ptr = D2.diagonal().data();
...@@ -222,7 +223,10 @@ namespace Faust ...@@ -222,7 +223,10 @@ namespace Faust
for(int i=0;i < Y_nrows; i++) for(int i=0;i < Y_nrows; i++)
Y_mat.row(i) = d1_ptr[i] * X_mat.row(i) + d2_ptr[i] * X_mat.row(subdiag_ids[i]); Y_mat.row(i) = d1_ptr[i] * X_mat.row(i) + d2_ptr[i] * X_mat.row(subdiag_ids[i]);
#else #else
Y_mat = D1 * X_mat + D2 * X_mat(subdiag_ids, Eigen::placeholders::all); if(do_transp)
Y_mat = D1 * X_mat + D2T * X_mat(subdiag_ids, Eigen::placeholders::all);
else
Y_mat = D1 * X_mat + D2 * X_mat(subdiag_ids, Eigen::placeholders::all);
#endif #endif
} }
...@@ -244,13 +248,18 @@ namespace Faust ...@@ -244,13 +248,18 @@ namespace Faust
template<typename FPP> template<typename FPP>
void MatButterfly<FPP, Cpu>::conjugate(const bool eval) void MatButterfly<FPP, Cpu>::conjugate(const bool eval)
{ {
//TODO auto size = getNbRow();
VecMap d1_vec(const_cast<FPP*>(D1.diagonal().data()), size); // const_cast is harmless
VecMap d2_vec(const_cast<FPP*>(D2.diagonal().data()), size); // const_cast is harmless
d1_vec = d1_vec.conjugate();
d2_vec = d2_vec.conjugate();
} }
template<typename FPP> template<typename FPP>
void MatButterfly<FPP, Cpu>::adjoint() void MatButterfly<FPP, Cpu>::adjoint()
{ {
//TODO transpose();
conjugate();
} }
template<typename FPP> template<typename FPP>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment