Mentions légales du service

Skip to content
Snippets Groups Projects
Commit af99ab90 authored by hhakim's avatar hhakim
Browse files

Implement TransformHelperButterfly multiplication by MatSparse.

(#275)
parent 15bd4284
Branches
Tags
No related merge requests found
...@@ -45,5 +45,20 @@ int main(int argc, char** argv) ...@@ -45,5 +45,20 @@ int main(int argc, char** argv)
std::cout << "Faust-dense matrix product OK" << std::endl; std::cout << "Faust-dense matrix product OK" << std::endl;
// test multiplying a MatDense
auto spX = MatSparse<FPP, Cpu>::randMat(size, size, .2);
// X->setOnes();
auto refYsp = F->multiply(*spX);
auto testYsp = oF->multiply(*spX);
auto errYsp = testYsp;
errYsp -= refYsp;
assert(errYsp.norm() <= 1e-6);
std::cout << "Faust-sparse matrix product OK" << std::endl;
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
...@@ -132,6 +132,9 @@ namespace Faust ...@@ -132,6 +132,9 @@ namespace Faust
template<typename FPP,FDevice DEVICE> template<typename FPP,FDevice DEVICE>
class MatBSR; class MatBSR;
template<typename FPP, FDevice DEV> class TransformHelperButterfly;
template<typename FPP> class ButterflyMat;
template<typename FPP> template<typename FPP>
class MatDense<FPP,Cpu> : public MatGeneric<FPP,Cpu> class MatDense<FPP,Cpu> : public MatGeneric<FPP,Cpu>
{ {
...@@ -141,6 +144,8 @@ namespace Faust ...@@ -141,6 +144,8 @@ namespace Faust
friend TransformHelper<FPP,Cpu>; // TODO: limit to needed member functions only friend TransformHelper<FPP,Cpu>; // TODO: limit to needed member functions only
friend Transform<FPP,Cpu>; //TODO: limit to needed member functions only (multiply) friend Transform<FPP,Cpu>; //TODO: limit to needed member functions only (multiply)
friend void MatDiag<FPP>::multiply(MatDense<FPP,Cpu> & M, char opThis) const; friend void MatDiag<FPP>::multiply(MatDense<FPP,Cpu> & M, char opThis) const;
friend TransformHelperButterfly<FPP, Cpu>;
friend ButterflyMat<FPP>;
/// All derived class template of MatDense are considered as friends /// All derived class template of MatDense are considered as friends
template<class,FDevice> friend class MatDense; template<class,FDevice> friend class MatDense;
......
...@@ -97,6 +97,8 @@ namespace Faust ...@@ -97,6 +97,8 @@ namespace Faust
template<typename FPP, FDevice DEVICE, typename FPP2> class GivensFGFTParallel; template<typename FPP, FDevice DEVICE, typename FPP2> class GivensFGFTParallel;
template<typename FPP, FDevice DEVICE, typename FPP2> class GivensFGFTComplex; template<typename FPP, FDevice DEVICE, typename FPP2> class GivensFGFTComplex;
template<typename FPP> class TransformHelperPoly; template<typename FPP> class TransformHelperPoly;
template<typename FPP, FDevice DEV> class TransformHelperButterfly;
template<typename FPP> class ButterflyMat;
//TODO: simplify/remove the friendship by adding/using a public setter to is_ortho //TODO: simplify/remove the friendship by adding/using a public setter to is_ortho
//template<typename FPP> void wht_factors(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool); //template<typename FPP> void wht_factors(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool);
template<typename FPP> template<typename FPP>
...@@ -117,6 +119,8 @@ namespace Faust ...@@ -117,6 +119,8 @@ namespace Faust
friend Transform<FPP,Cpu>; //TODO: limit to needed member functions only (multiply) friend Transform<FPP,Cpu>; //TODO: limit to needed member functions only (multiply)
friend TransformHelper<FPP,Cpu>; // TODO: limit to needed member functions only friend TransformHelper<FPP,Cpu>; // TODO: limit to needed member functions only
friend TransformHelperPoly<FPP>; // TODO: limit to needed member functions only friend TransformHelperPoly<FPP>; // TODO: limit to needed member functions only
friend TransformHelperButterfly<FPP, Cpu>; // TODO: limit to needed member functions only
friend ButterflyMat<FPP>; // TODO: limit to needed member functions only
friend void wht_factors<>(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool); friend void wht_factors<>(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool);
friend class MatDense<FPP,Cpu>; friend class MatDense<FPP,Cpu>;
friend class MatSparse<std::complex<double>, Cpu>; friend class MatSparse<std::complex<double>, Cpu>;
......
...@@ -29,6 +29,7 @@ namespace Faust ...@@ -29,6 +29,7 @@ namespace Faust
Vect<FPP,Cpu> multiply(const FPP* x); Vect<FPP,Cpu> multiply(const FPP* x);
void multiply(const FPP* A, int A_ncols, FPP* C); void multiply(const FPP* A, int A_ncols, FPP* C);
MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A); MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A);
MatDense<FPP, Cpu> multiply(const MatSparse<FPP,Cpu> &A);
}; };
...@@ -51,6 +52,7 @@ namespace Faust ...@@ -51,6 +52,7 @@ namespace Faust
void multiply(const FPP* A, int A_ncols, FPP* C, size_t size); void multiply(const FPP* A, int A_ncols, FPP* C, size_t size);
MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A); MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A);
// MatDense<FPP, Cpu> multiply(const MatSparse<FPP,Cpu> &A);
}; };
} }
......
...@@ -115,10 +115,30 @@ namespace Faust ...@@ -115,10 +115,30 @@ namespace Faust
} }
template<typename FPP> template<typename FPP>
MatDense<FPP, Cpu> TransformHelperButterfly<FPP,Cpu>::multiply(const MatDense<FPP,Cpu> &A) MatDense<FPP, Cpu> TransformHelperButterfly<FPP,Cpu>::multiply(const MatDense<FPP,Cpu> &X)
{ {
MatDense<FPP, Cpu> Y(this->getNbRow(), A.getNbCol()); MatDense<FPP, Cpu> Y(this->getNbRow(), X.getNbCol());
multiply(A.getData(), A.getNbCol(), Y.getData()); multiply(X.getData(), X.getNbCol(), Y.getData());
return Y;
}
template<typename FPP>
MatDense<FPP, Cpu> TransformHelperButterfly<FPP,Cpu>::multiply(const MatSparse<FPP,Cpu> &X)
{
MatDense<FPP, Cpu> Y(this->getNbRow(), X.getNbCol());
for(int i=0;i < this->getNbRow(); i ++)
Y.mat.row(i) = X.mat.row(bitrev_perm[i]) * perm_d.getData()[i];
// auto Z = new FPP[this->getNbRow()*X.getNbCol()];
// for(auto fac: opt_factors)
// {
// fac.multiply(Y.getData(), Y.mat.cols(), Z, this->getNbRow());
// memcpy(Y.getData(), Z, sizeof(FPP)*this->getNbRow()*X.getNbCol());
// }
// delete[] Z;
for(auto fac: opt_factors)
{
Y = fac.multiply(Y);
}
return Y; return Y;
} }
...@@ -212,10 +232,10 @@ namespace Faust ...@@ -212,10 +232,10 @@ namespace Faust
} }
template<typename FPP> template<typename FPP>
MatDense<FPP, Cpu> ButterflyMat<FPP>::multiply(const MatDense<FPP,Cpu> &A) MatDense<FPP, Cpu> ButterflyMat<FPP>::multiply(const MatDense<FPP,Cpu> &X)
{ {
MatDense<FPP, Cpu> Y(A.getNbrow(), A.getNbCol()); MatDense<FPP, Cpu> Y(X.getNbRow(), X.getNbCol());
multiply(A.getData(), A.getNbCol(), Y.getData(), A.getNbRow()); multiply(X.getData(), X.getNbCol(), Y.getData(), X.getNbRow());
return Y; return Y;
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment