Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 1211fb8b authored by hhakim's avatar hhakim
Browse files

Implement GPU Butterfly Faust optimization in the C++ core.

parent 4919663d
No related branches found
No related tags found
No related merge requests found
...@@ -133,7 +133,7 @@ namespace Faust ...@@ -133,7 +133,7 @@ namespace Faust
class MatBSR; class MatBSR;
template<typename FPP, FDevice DEV> class TransformHelperButterfly; template<typename FPP, FDevice DEV> class TransformHelperButterfly;
template<typename FPP> class ButterflyMat; template<typename FPP, FDevice DEV> class ButterflyMat;
template<typename FPP> template<typename FPP>
class MatDense<FPP,Cpu> : public MatGeneric<FPP,Cpu> class MatDense<FPP,Cpu> : public MatGeneric<FPP,Cpu>
...@@ -145,7 +145,7 @@ namespace Faust ...@@ -145,7 +145,7 @@ namespace Faust
friend Transform<FPP,Cpu>; //TODO: limit to needed member functions only (multiply) friend Transform<FPP,Cpu>; //TODO: limit to needed member functions only (multiply)
friend void MatDiag<FPP>::multiply(MatDense<FPP,Cpu> & M, char opThis) const; friend void MatDiag<FPP>::multiply(MatDense<FPP,Cpu> & M, char opThis) const;
friend TransformHelperButterfly<FPP, Cpu>; friend TransformHelperButterfly<FPP, Cpu>;
friend ButterflyMat<FPP>; friend ButterflyMat<FPP, Cpu>;
/// All derived class template of MatDense are considered as friends /// All derived class template of MatDense are considered as friends
template<class,FDevice> friend class MatDense; template<class,FDevice> friend class MatDense;
......
...@@ -98,7 +98,7 @@ namespace Faust ...@@ -98,7 +98,7 @@ namespace Faust
template<typename FPP, FDevice DEVICE, typename FPP2> class GivensFGFTComplex; template<typename FPP, FDevice DEVICE, typename FPP2> class GivensFGFTComplex;
template<typename FPP> class TransformHelperPoly; template<typename FPP> class TransformHelperPoly;
template<typename FPP, FDevice DEV> class TransformHelperButterfly; template<typename FPP, FDevice DEV> class TransformHelperButterfly;
template<typename FPP> class ButterflyMat; template<typename FPP, FDevice DEV> class ButterflyMat;
//TODO: simplify/remove the friendship by adding/using a public setter to is_ortho //TODO: simplify/remove the friendship by adding/using a public setter to is_ortho
//template<typename FPP> void wht_factors(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool); //template<typename FPP> void wht_factors(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool);
template<typename FPP> template<typename FPP>
...@@ -120,7 +120,7 @@ namespace Faust ...@@ -120,7 +120,7 @@ namespace Faust
friend TransformHelper<FPP,Cpu>; // TODO: limit to needed member functions only friend TransformHelper<FPP,Cpu>; // TODO: limit to needed member functions only
friend TransformHelperPoly<FPP>; // TODO: limit to needed member functions only friend TransformHelperPoly<FPP>; // TODO: limit to needed member functions only
friend TransformHelperButterfly<FPP, Cpu>; // TODO: limit to needed member functions only friend TransformHelperButterfly<FPP, Cpu>; // TODO: limit to needed member functions only
friend ButterflyMat<FPP>; // TODO: limit to needed member functions only friend ButterflyMat<FPP, Cpu>; // TODO: limit to needed member functions only
friend void wht_factors<>(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool); friend void wht_factors<>(unsigned int n, std::vector<MatGeneric<FPP,Cpu>*>& factors, const bool, const bool);
friend class MatDense<FPP,Cpu>; friend class MatDense<FPP,Cpu>;
friend class MatSparse<std::complex<double>, Cpu>; friend class MatSparse<std::complex<double>, Cpu>;
......
...@@ -26,7 +26,7 @@ namespace Faust ...@@ -26,7 +26,7 @@ namespace Faust
template<typename FPP, FDevice DEV> template<typename FPP, FDevice DEV>
class TransformHelperButterfly; class TransformHelperButterfly;
template<typename FPP> template<typename FPP, FDevice DEV>
class ButterflyMat; class ButterflyMat;
template<typename FPP> template<typename FPP>
...@@ -38,7 +38,7 @@ namespace Faust ...@@ -38,7 +38,7 @@ namespace Faust
FPP *perm_d_ptr; FPP *perm_d_ptr;
DiagMat D; DiagMat D;
std::vector<unsigned int> bitrev_perm; std::vector<unsigned int> bitrev_perm;
std::vector<ButterflyMat<FPP>> opt_factors; std::vector<ButterflyMat<FPP, Cpu>> opt_factors;
// private ctor // private ctor
...@@ -59,7 +59,7 @@ namespace Faust ...@@ -59,7 +59,7 @@ namespace Faust
}; };
template<typename FPP> template<typename FPP>
class ButterflyMat class ButterflyMat<FPP, Cpu>
{ {
using VecMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, 1>>; using VecMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, 1>>;
...@@ -74,7 +74,7 @@ namespace Faust ...@@ -74,7 +74,7 @@ namespace Faust
// \param level: is a 0-base index. // \param level: is a 0-base index.
public: public:
ButterflyMat<FPP>(const MatSparse<FPP, Cpu> &factor, int level); ButterflyMat<FPP, Cpu>(const MatSparse<FPP, Cpu> &factor, int level);
Vect<FPP, Cpu> multiply(const Vect<FPP, Cpu>& x) const; Vect<FPP, Cpu> multiply(const Vect<FPP, Cpu>& x) const;
void multiply(const FPP* x, FPP* y, size_t size) const; void multiply(const FPP* x, FPP* y, size_t size) const;
...@@ -86,6 +86,7 @@ namespace Faust ...@@ -86,6 +86,7 @@ namespace Faust
public: public:
const DiagMat& getD1() {return D1;}; const DiagMat& getD1() {return D1;};
const DiagMat& getD2() {return D2;}; const DiagMat& getD2() {return D2;};
const std::vector<int>& get_subdiag_ids() {return subdiag_ids;}
}; };
} }
#include "faust_TransformHelperButterfly.hpp" #include "faust_TransformHelperButterfly.hpp"
......
...@@ -15,7 +15,7 @@ namespace Faust ...@@ -15,7 +15,7 @@ namespace Faust
for(auto csr_fac_it = this->begin(); csr_fac_it != end_it; csr_fac_it++) for(auto csr_fac_it = this->begin(); csr_fac_it != end_it; csr_fac_it++)
{ {
auto csr_fac = *csr_fac_it; auto csr_fac = *csr_fac_it;
opt_factors.insert(opt_factors.begin(), ButterflyMat<FPP>(*dynamic_cast<const MatSparse<FPP, Cpu>*>(csr_fac), i++)); opt_factors.insert(opt_factors.begin(), ButterflyMat<FPP, Cpu>(*dynamic_cast<const MatSparse<FPP, Cpu>*>(csr_fac), i++));
} }
if(has_permutation) if(has_permutation)
{ {
...@@ -119,7 +119,7 @@ namespace Faust ...@@ -119,7 +119,7 @@ namespace Faust
} }
else else
{ {
ButterflyMat<FPP>& fac = opt_factors[0]; ButterflyMat<FPP, Cpu>& fac = opt_factors[0];
fac.multiply(x, z.getData(), this->getNbRow()); fac.multiply(x, z.getData(), this->getNbRow());
i = 1; i = 1;
} }
...@@ -127,7 +127,7 @@ namespace Faust ...@@ -127,7 +127,7 @@ namespace Faust
while(i < opt_factors.size()) while(i < opt_factors.size())
// for(auto fac: opt_factors) // for(auto fac: opt_factors)
{ {
ButterflyMat<FPP>& fac = opt_factors[i]; ButterflyMat<FPP, Cpu>& fac = opt_factors[i];
if(i & 1) if(i & 1)
fac.multiply(z.getData(), y, this->getNbRow()); fac.multiply(z.getData(), y, this->getNbRow());
else else
...@@ -167,7 +167,7 @@ namespace Faust ...@@ -167,7 +167,7 @@ namespace Faust
} }
else else
{ {
ButterflyMat<FPP>& fac = opt_factors[0]; ButterflyMat<FPP, Cpu>& fac = opt_factors[0];
fac.multiply(X, X_mat.cols(), Z, this->getNbRow()); fac.multiply(X, X_mat.cols(), Z, this->getNbRow());
i = 1; i = 1;
} }
...@@ -175,7 +175,7 @@ namespace Faust ...@@ -175,7 +175,7 @@ namespace Faust
while(i < opt_factors.size()) while(i < opt_factors.size())
// for(auto fac: opt_factors) // for(auto fac: opt_factors)
{ {
ButterflyMat<FPP>& fac = opt_factors[i]; ButterflyMat<FPP, Cpu>& fac = opt_factors[i];
if(i & 1) if(i & 1)
fac.multiply(Z, Y_mat.cols(), Y, this->getNbRow()); fac.multiply(Z, Y_mat.cols(), Y, this->getNbRow());
...@@ -218,7 +218,7 @@ namespace Faust ...@@ -218,7 +218,7 @@ namespace Faust
// for(auto fac: opt_factors) // for(auto fac: opt_factors)
while(i < opt_factors.size()) while(i < opt_factors.size())
{ {
ButterflyMat<FPP>& fac = opt_factors[i]; ButterflyMat<FPP, Cpu>& fac = opt_factors[i];
Y = fac.multiply(Y); Y = fac.multiply(Y);
i++; i++;
} }
...@@ -227,7 +227,7 @@ namespace Faust ...@@ -227,7 +227,7 @@ namespace Faust
//TODO: factorize with MatDense product //TODO: factorize with MatDense product
while(i < opt_factors.size()) while(i < opt_factors.size())
{ {
ButterflyMat<FPP>& fac = opt_factors[i]; ButterflyMat<FPP, Cpu>& fac = opt_factors[i];
if(i & 1) if(i & 1)
fac.multiply(Z, Y.mat.cols(), Y.getData(), this->getNbRow()); fac.multiply(Z, Y.mat.cols(), Y.getData(), this->getNbRow());
else else
...@@ -248,7 +248,7 @@ namespace Faust ...@@ -248,7 +248,7 @@ namespace Faust
namespace Faust namespace Faust
{ {
template<typename FPP> template<typename FPP>
ButterflyMat<FPP>::ButterflyMat(const MatSparse<FPP, Cpu> &factor, int level) ButterflyMat<FPP, Cpu>::ButterflyMat(const MatSparse<FPP, Cpu> &factor, int level)
{ {
// build a d1, d2 pair from the butterfly factor // build a d1, d2 pair from the butterfly factor
auto size = factor.getNbRow(); auto size = factor.getNbRow();
...@@ -291,8 +291,9 @@ namespace Faust ...@@ -291,8 +291,9 @@ namespace Faust
} }
template<typename FPP> template<typename FPP>
void ButterflyMat<FPP>::Display() const void ButterflyMat<FPP, Cpu>::Display() const
{ {
std::cout << "ButterflyMat on CPU: ";
std::cout << "D1: "; std::cout << "D1: ";
std::cout << D1.diagonal() << std::endl; std::cout << D1.diagonal() << std::endl;
std::cout << "D2: "; std::cout << "D2: ";
...@@ -304,7 +305,7 @@ namespace Faust ...@@ -304,7 +305,7 @@ namespace Faust
} }
template<typename FPP> template<typename FPP>
Vect<FPP, Cpu> ButterflyMat<FPP>::multiply(const Vect<FPP, Cpu>& x) const Vect<FPP, Cpu> ButterflyMat<FPP, Cpu>::multiply(const Vect<FPP, Cpu>& x) const
{ {
Vect<FPP, Cpu> z(x.size()); Vect<FPP, Cpu> z(x.size());
multiply(x.getData(), z.getData(), x.size()); multiply(x.getData(), z.getData(), x.size());
...@@ -312,7 +313,7 @@ namespace Faust ...@@ -312,7 +313,7 @@ namespace Faust
} }
template<typename FPP> template<typename FPP>
void ButterflyMat<FPP>::multiply(const FPP* x, FPP* y, size_t size) const void ButterflyMat<FPP, Cpu>::multiply(const FPP* x, FPP* y, size_t size) const
{ {
const FPP *d1_ptr = D1.diagonal().data(), *d2_ptr = D2.diagonal().data(); const FPP *d1_ptr = D1.diagonal().data(), *d2_ptr = D2.diagonal().data();
#ifdef USE_PYTHONIC #ifdef USE_PYTHONIC
...@@ -339,7 +340,7 @@ namespace Faust ...@@ -339,7 +340,7 @@ namespace Faust
} }
template<typename FPP> template<typename FPP>
void ButterflyMat<FPP>::multiply(const FPP* X, int X_ncols, FPP* Y, size_t Y_nrows) void ButterflyMat<FPP, Cpu>::multiply(const FPP* X, int X_ncols, FPP* Y, size_t Y_nrows)
{ {
using MatMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, Eigen::Dynamic>>; using MatMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, Eigen::Dynamic>>;
MatMap X_mat(const_cast<FPP*>(X) /* harmless, no modification*/, Y_nrows, X_ncols); MatMap X_mat(const_cast<FPP*>(X) /* harmless, no modification*/, Y_nrows, X_ncols);
...@@ -356,7 +357,7 @@ namespace Faust ...@@ -356,7 +357,7 @@ namespace Faust
} }
template<typename FPP> template<typename FPP>
MatDense<FPP, Cpu> ButterflyMat<FPP>::multiply(const MatDense<FPP,Cpu> &X) MatDense<FPP, Cpu> ButterflyMat<FPP, Cpu>::multiply(const MatDense<FPP,Cpu> &X)
{ {
MatDense<FPP, Cpu> Y(X.getNbRow(), X.getNbCol()); MatDense<FPP, Cpu> Y(X.getNbRow(), X.getNbCol());
multiply(X.getData(), X.getNbCol(), Y.getData(), X.getNbRow()); multiply(X.getData(), X.getNbCol(), Y.getData(), X.getNbRow());
......
#ifndef __FAUST_TRANSFORM_HELPER_DFT_GPU2__ #ifndef __FAUST_TRANSFORM_HELPER_DFT_GPU2__
#define __FAUST_TRANSFORM_HELPER_DFT_GPU2__ #define __FAUST_TRANSFORM_HELPER_DFT_GPU2__
#ifdef USE_GPU_MOD
#include "faust_TransformHelper_gpu.h" #include "faust_TransformHelper_gpu.h"
namespace Faust namespace Faust
...@@ -7,36 +8,64 @@ namespace Faust ...@@ -7,36 +8,64 @@ namespace Faust
template<typename FPP, FDevice DEV> template<typename FPP, FDevice DEV>
class TransformHelperButterfly; class TransformHelperButterfly;
template<typename FPP> template<typename FPP, FDevice DEV>
class ButterflyMat; class ButterflyMat;
template<typename FPP> template<typename FPP>
class TransformHelperButterfly<FPP, GPU2> : public TransformHelper<FPP, GPU2> class TransformHelperButterfly<FPP, GPU2> : public TransformHelper<FPP, GPU2>
{ {
// using VecMap = Eigen::Map<Eigen::Matrix<FPP, Eigen::Dynamic, 1>>; int* perm_ids;
// using DiagMat = Eigen::DiagonalMatrix<FPP, Eigen::Dynamic>; Vect<FPP, GPU2> d_perm;
// FPP *perm_d_ptr; bool has_permutation;
// DiagMat D; std::vector<ButterflyMat<FPP, GPU2>> opt_factors;
// std::vector<unsigned int> bitrev_perm;
// std::vector<ButterflyMat<FPP>> opt_factors;
//
//
// // private ctor
// TransformHelperButterfly<FPP, GPU2>(const std::vector<MatGeneric<FPP,GPU2> *>& facts, const FPP lambda_ = (FPP)1.0, const bool optimizedCopy=false, const bool cloning_fact = true, const bool internal_call=false);
//
// private ctor
TransformHelperButterfly(const std::vector<MatGeneric<FPP, Cpu> *>& facts, const FPP lambda_ = (FPP)1.0, const bool optimizedCopy=false, const bool cloning_fact = true, const bool internal_call=false);
~TransformHelperButterfly() { delete[] perm_ids;}
public: public:
static TransformHelper<FPP,GPU2>* fourierFaust(unsigned int n, const bool norma=true) { throw std::runtime_error("Not yet implemented on GPU");}; static TransformHelper<FPP,GPU2>* fourierFaust(unsigned int n, const bool norma=true);
static TransformHelper<FPP,GPU2>* optFaust(const TransformHelper<FPP, GPU2>* F) { throw std::runtime_error("Not yet implemented on GPU");}; static TransformHelper<FPP,GPU2>* optFaust(const TransformHelper<FPP, GPU2>* F) { throw std::runtime_error("Not yet implemented on GPU");};
// Vect<FPP, GPU2> multiply(const Vect<FPP, GPU2>& x); Vect<FPP, Cpu> multiply(const Vect<FPP, Cpu>& x);
// void multiply(const FPP* x, FPP* y); void multiply(const FPP* x, FPP* y);
// Vect<FPP,GPU2> multiply(const FPP* x); Vect<FPP,Cpu> multiply(const FPP* x);
// void multiply(const FPP* A, int A_ncols, FPP* C); void multiply(const FPP* A, int A_ncols, FPP* C);
// MatDense<FPP, GPU2> multiply(const MatDense<FPP,GPU2> &A); MatDense<FPP, Cpu> multiply(const MatDense<FPP,Cpu> &A);
// MatDense<FPP, GPU2> multiply(const MatSparse<FPP,GPU2> &A); MatDense<FPP, Cpu> multiply(const MatSparse<FPP,Cpu> &A);
}; };
template<typename FPP>
class ButterflyMat<FPP, GPU2>
{
Vect<FPP, GPU2> d1;
Vect<FPP, GPU2> d2;
int* subdiag_ids;
#ifdef USE_PYTHONIC
long *subdiag_ids_ptr;
#endif
int level;
// \param level: is a 0-base index.
public:
ButterflyMat(const MatSparse<FPP, Cpu> &factor, int level);
//TODO: constness of multiply member functions
MatDense<FPP, GPU2> multiply(const FPP* x);
void Display() const;
MatDense<FPP, GPU2> multiply(const FPP* A, int A_ncols);
MatDense<FPP, GPU2> multiply(MatDense<FPP,GPU2> &A);
void multiply(MatDense<FPP,GPU2> &A, MatDense<FPP, Cpu> & out);
// MatDense<FPP, Cpu> multiply(const MatSparse<FPP,Cpu> &A);
const Vect<FPP, GPU2>& getD1() {return d1;};
const Vect<FPP, GPU2>& getD2() {return d2;};
~ButterflyMat() { delete[] subdiag_ids;}
};
} }
//#include "faust_TransformHelperButterfly_gpu.hpp" //TODO #include "faust_TransformHelperButterfly_gpu.hpp" //TODO
#endif
#endif #endif
#include "faust_TransformHelperButterfly.h"
namespace Faust
{
template<typename FPP>
TransformHelperButterfly<FPP, GPU2>::TransformHelperButterfly(const std::vector<MatGeneric<FPP,Cpu> *>& facts, const FPP lambda_ /*= (FPP)1.0*/, const bool optimizedCopy/*=false*/, const bool cloning_fact /*= true*/, const bool internal_call/*=false*/)
{
int i = 0;
auto size = this->getNbRow();
// for(auto csr_fac: facts)
// use rather recorded factors in the Faust::Transform because one might have been multiplied with lambda_
auto log2nf = 1 << (this->size() - 1);
has_permutation = (log2nf - this->getNbRow()) == 0;
auto end_it = has_permutation?this->end()-1:this->end();
for(auto csr_fac_it = this->begin(); csr_fac_it != end_it; csr_fac_it++)
{
auto csr_fac = *csr_fac_it;
opt_factors.insert(opt_factors.begin(), ButterflyMat<FPP, GPU2>(*dynamic_cast<const MatSparse<FPP, Cpu>*>(csr_fac), i++));
this->push_back(csr_fac);
}
if(has_permutation)
{
// set the permutation factor
auto csr_fac = dynamic_cast<const MatSparse<FPP, Cpu>*>(*(this->end()-1));
this->push_back(csr_fac);
d_perm.resize(size);
// only ones should be enough because this is a permutation matrix but it could be normalized
d_perm = Vect<FPP, GPU2>(size, csr_fac->getValuePtr());
perm_ids = new int[size];
copy(csr_fac->getColInd(), csr_fac->getColInd()+size, perm_ids);
}
}
template<typename FPP>
void TransformHelperButterfly<FPP, GPU2>::multiply(const FPP* A, int A_ncols, FPP* C)
{
MatDense<FPP, GPU2> gpu_X(d_perm.size(), A_ncols, A);
if(has_permutation)
gpu_X.eltwise_mul(d_perm, perm_ids);
for(auto gpu_bmat: opt_factors)
gpu_bmat.multiply(gpu_X);
gpu_X.tocpu(C);
}
template<typename FPP>
Vect<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const Vect<FPP, Cpu>& x)
{
Vect<FPP, Cpu> y;
y.resize(d_perm.size());
multiply(x.getData(), y.getData());
return y;
}
template<typename FPP>
void TransformHelperButterfly<FPP, GPU2>::multiply(const FPP* x, FPP* y)
{
multiply(x, 1, y);
}
template<typename FPP>
Vect<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const FPP* x)
{
Vect<FPP, Cpu> y;
y.resize(d_perm.size());
multiply(x, 1, y.getData());
return y;
}
template<typename FPP>
MatDense<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const MatDense<FPP,Cpu> &A)
{
MatDense<FPP, Cpu> out;
out.resize(d_perm.size(), A.getNbCol());
multiply(A.getData(), A.getNbCol(), out.getData());
return out;
}
template<typename FPP>
MatDense<FPP, Cpu> TransformHelperButterfly<FPP,GPU2>::multiply(const MatSparse<FPP,Cpu> &X)
{
return multiply(MatDense<FPP, GPU2>(X));
}
template<typename FPP>
TransformHelper<FPP,GPU2>* TransformHelperButterfly<FPP,GPU2>::fourierFaust(unsigned int n, const bool norma/*=true*/)
{
std::vector<MatGeneric<FPP,Cpu>*> factors(n+1);
TransformHelper<FPP, GPU2>* fourierFaust = nullptr;
try
{
fft_factors(n, factors);
FPP alpha = norma?FPP(1/sqrt((double)(1 << n))):FPP(1.0);
fourierFaust = new TransformHelperButterfly<FPP, GPU2>(factors, alpha, false, false, /* internal call */ true);
}
catch(std::bad_alloc e)
{
//nothing to do, out of memory, return nullptr
}
return fourierFaust;
}
template<typename FPP>
ButterflyMat<FPP, GPU2>::ButterflyMat(const MatSparse<FPP, Cpu> &factor, int level) : level(level)
{
ButterflyMat<FPP, Cpu> cpu_bmat(factor, level);
auto cpu_d1 = cpu_bmat.getD1();
auto cpu_d2 = cpu_bmat.getD2();
d1 = Vect<FPP, GPU2>(cpu_d1.size(), cpu_d1.diagonal().data());
d2 = Vect<FPP, GPU2>(cpu_d2.size(), cpu_d2.diagonal().data());
auto sd_ids_vec = cpu_bmat.get_subdiag_ids();
subdiag_ids = new int[sd_ids_vec.size()];
memcpy(subdiag_ids, sd_ids_vec.data(), sizeof(int) * sd_ids_vec.size());
}
template<typename FPP>
void ButterflyMat<FPP, GPU2>::Display() const
{
std::cout << "ButterflyMat on GPU: ";
std::cout << "D1: ";
d1.Display();
std::cout << "D2: ";
d1.Display();
cout << "subdiag_ids: ";
for(int i=0;i < d1.size();i++)
cout << subdiag_ids[i] << " ";
cout << std::endl;
}
template<typename FPP>
MatDense<FPP, GPU2> ButterflyMat<FPP, GPU2>::multiply(const FPP* X, int X_ncols)
{
MatDense<FPP, GPU2> gpu_X(d1.size(), X_ncols, X);
return multiply(gpu_X);
}
template<typename FPP>
MatDense<FPP, GPU2> ButterflyMat<FPP, GPU2>::multiply(const FPP* x)
{
return multiply(x, 1);
}
template<typename FPP>
MatDense<FPP, GPU2> ButterflyMat<FPP, GPU2>::multiply(MatDense<FPP, GPU2> &gpu_X)
{
MatDense<FPP, GPU2> gpu_X2(gpu_X);
gpu_X.eltwise_mul(d2, subdiag_ids);
gpu_X2.eltwise_mul(d1);
gpu_X += gpu_X2;
return gpu_X;
}
template<typename FPP>
void ButterflyMat<FPP, GPU2>::multiply(MatDense<FPP, GPU2> &gpu_X, MatDense<FPP, Cpu> &cpu_out)
{
multiply(gpu_X);
gpu_X.tocpu(cpu_out);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment