Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 15bd4284 authored by hhakim's avatar hhakim
Browse files

Add multiply(MatDense) to TransformHelperButterfly.

(#275)
parent d7a79488
Branches
Tags
No related merge requests found
......@@ -5,12 +5,16 @@ typedef @TEST_FPP@ FPP;
using namespace Faust;
int main()
int main(int argc, char** argv)
{
int log2size = 5;
int log2size = 4;
if(argc > 1)
log2size = std::atoi(argv[1]);
std::cout << "log2size: " << log2size << std::endl;
int size = 1 << log2size;
auto oF = TransformHelperButterfly<FPP, Cpu>::fourierFaust(5, false);
auto F = TransformHelper<FPP, Cpu>::fourierFaust(5, false);
auto oF = TransformHelperButterfly<FPP, Cpu>::fourierFaust(log2size, false);
auto F = TransformHelper<FPP, Cpu>::fourierFaust(log2size, false);
Vect<FPP, Cpu> x(size);
x.setRand();
......@@ -24,6 +28,22 @@ int main()
auto err = test_v;
err -= ref_v;
assert(err.norm() <= 1e-6);
std::cout << "Faust-vector product OK" << std::endl;
// test multiplying a MatDense
auto X = MatDense<FPP, Cpu>::randMat(size, size);
// X->setOnes();
auto refY = F->multiply(*X);
auto testY = oF->multiply(*X);
auto errY = testY;
errY -= refY;
assert(errY.norm() <= 1e-6);
std::cout << "Faust-dense matrix product OK" << std::endl;
return EXIT_SUCCESS;
}
......@@ -110,7 +110,6 @@ namespace Faust
virtual void multiply(const FPP* A, int A_ncols, FPP* C);
// MatDense<FPP,Cpu> multiply(const MatDense<FPP,Cpu> A) const;
virtual MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A);
virtual void update_total_nnz();
virtual MatDense<FPP, Cpu> multiply(const MatSparse<FPP,Cpu> &A);
virtual TransformHelper<FPP, Cpu>* multiply(const TransformHelper<FPP, Cpu>*) const;
......@@ -130,6 +129,7 @@ namespace Faust
void push_first(const MatGeneric<FPP,Cpu>* M, const bool optimizedCopy=false, const bool copying=true);
virtual faust_unsigned_int getNBytes() const;
virtual faust_unsigned_int get_total_nnz() const;
virtual void update_total_nnz();
bool is_zero() const;
faust_unsigned_int size() const;
virtual void resize(faust_unsigned_int);
......
......@@ -27,6 +27,9 @@ namespace Faust
Vect<FPP, Cpu> multiply(const Vect<FPP, Cpu>& x);
void multiply(const FPP* x, FPP* y);
Vect<FPP,Cpu> multiply(const FPP* x);
void multiply(const FPP* A, int A_ncols, FPP* C);
MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A);
};
template<typename FPP>
......@@ -46,6 +49,9 @@ namespace Faust
void multiply(const FPP* x, FPP* y, size_t size) const;
void Display() const;
void multiply(const FPP* A, int A_ncols, FPP* C, size_t size);
MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A);
};
}
#include "faust_TransformHelperButterfly.hpp"
......
......@@ -96,6 +96,32 @@ namespace Faust
multiply(x, y.getData());
return y;
}
template<typename FPP>
void TransformHelperButterfly<FPP,Cpu>::multiply(const FPP* X, int X_ncols, FPP* Y)
{
using MatMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, Eigen::Dynamic>>;
MatMap X_mat(const_cast<FPP*>(X) /* harmless, no modification*/, this->getNbCol(), X_ncols);
MatMap Y_mat(Y, this->getNbRow(), X_ncols);
for(int i=0;i < this->getNbRow(); i ++)
Y_mat.row(i) = X_mat.row(bitrev_perm[i]) * perm_d.getData()[i];
auto Z = new FPP[this->getNbRow()*X_ncols];
for(auto fac: opt_factors)
{
fac.multiply(Y, Y_mat.cols(), Z, this->getNbRow());
memcpy(Y, Z, sizeof(FPP)*this->getNbRow()*X_ncols);
}
delete[] Z;
}
template<typename FPP>
MatDense<FPP, Cpu> TransformHelperButterfly<FPP,Cpu>::multiply(const MatDense<FPP,Cpu> &A)
{
MatDense<FPP, Cpu> Y(this->getNbRow(), A.getNbCol());
multiply(A.getData(), A.getNbCol(), Y.getData());
return Y;
}
}
......@@ -173,4 +199,23 @@ namespace Faust
for(int i=0;i < size; i++)
y[i] = d1_ptr[i] * x[i] + d2_ptr[i] * x[subdiag_ids[i]];
}
template<typename FPP>
void ButterflyMat<FPP>::multiply(const FPP* X, int X_ncols, FPP* Y, size_t Y_nrows)
{
using MatMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, Eigen::Dynamic>>;
MatMap X_mat(const_cast<FPP*>(X) /* harmless, no modification*/, Y_nrows, X_ncols);
MatMap Y_mat(Y, Y_nrows, X_ncols);
const FPP *d1_ptr = d1.getData(), *d2_ptr = d2.getData();
for(int i=0;i < Y_nrows; i++)
Y_mat.row(i) = d1_ptr[i] * X_mat.row(i) + d2_ptr[i] * X_mat.row(subdiag_ids[i]);
}
template<typename FPP>
MatDense<FPP, Cpu> ButterflyMat<FPP>::multiply(const MatDense<FPP,Cpu> &A)
{
MatDense<FPP, Cpu> Y(A.getNbrow(), A.getNbCol());
multiply(A.getData(), A.getNbCol(), Y.getData(), A.getNbRow());
return Y;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment