Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 4fed8d59 authored by hhakim's avatar hhakim
Browse files

Fix multiple errors in TransformHelperButterfly<FPP, GPU2> (wrong buffer size,...

Fix multiple errors in TransformHelperButterfly<FPP, GPU2> (wrong buffer size, forgetting the Faust scale factor, and other minor errors).
parent 435a5792
Branches
Tags
No related merge requests found
...@@ -6,24 +6,45 @@ namespace Faust ...@@ -6,24 +6,45 @@ namespace Faust
TransformHelperButterfly<FPP, GPU2>::TransformHelperButterfly(const std::vector<MatGeneric<FPP,Cpu> *>& facts, const FPP lambda_ /*= (FPP)1.0*/, const bool optimizedCopy/*=false*/, const bool cloning_fact /*= true*/, const bool internal_call/*=false*/) TransformHelperButterfly<FPP, GPU2>::TransformHelperButterfly(const std::vector<MatGeneric<FPP,Cpu> *>& facts, const FPP lambda_ /*= (FPP)1.0*/, const bool optimizedCopy/*=false*/, const bool cloning_fact /*= true*/, const bool internal_call/*=false*/)
{ {
int i = 0; int i = 0;
auto size = this->getNbRow(); auto size = facts[0]->getNbRow();
// for(auto csr_fac: facts) // for(auto csr_fac: facts)
// use rather recorded factors in the Faust::Transform because one might have been multiplied with lambda_ // use rather recorded factors in the Faust::Transform because one might have been multiplied with lambda_
auto log2nf = 1 << (this->size() - 1); auto log2nf = 1 << (facts.size() - 1);
has_permutation = (log2nf - this->getNbRow()) == 0; has_permutation = (log2nf - size) == 0;
auto end_it = has_permutation?this->end()-1:this->end(); auto end_it = has_permutation?this->end()-1:this->end();
for(auto csr_fac_it = this->begin(); csr_fac_it != end_it; csr_fac_it++) for(auto gen_fac: facts)
{ {
auto csr_fac = *csr_fac_it; auto csr_fac = dynamic_cast<const MatSparse<FPP, Cpu>*>(gen_fac);
opt_factors.insert(opt_factors.begin(), ButterflyMat<FPP, GPU2>(*dynamic_cast<const MatSparse<FPP, Cpu>*>(csr_fac), i++)); if(csr_fac == nullptr)
this->push_back(csr_fac); throw std::runtime_error("TransformHelperButterfly can receive only MatSparse CSR matrices");
if(i < facts.size()-1 || ! has_permutation)
{
if( i == 0)
{
auto mul_csr = new MatSparse<FPP, Cpu>(*csr_fac);
*mul_csr *= lambda_;
opt_factors.insert(opt_factors.begin(),
ButterflyMat<FPP, GPU2>(*mul_csr, i++));
this->push_back(mul_csr);
}
else
{
opt_factors.insert(opt_factors.begin(),
ButterflyMat<FPP, GPU2>(*csr_fac, i++));
this->push_back(csr_fac);
}
}
} }
if(has_permutation) if(has_permutation)
{ {
// set the permutation factor // set the permutation factor
auto csr_fac = dynamic_cast<const MatSparse<FPP, Cpu>*>(*(this->end()-1)); auto csr_fac = dynamic_cast<const MatSparse<FPP, Cpu>*>(*(facts.end()-1));
this->push_back(csr_fac); this->push_back(csr_fac);
d_perm.resize(size); d_perm.resize(size);
if(csr_fac->getNonZeros() != size)
throw std::runtime_error("Permutation matrix is not valid");
// only ones should be enough because this is a permutation matrix but it could be normalized // only ones should be enough because this is a permutation matrix but it could be normalized
d_perm = Vect<FPP, GPU2>(size, csr_fac->getValuePtr()); d_perm = Vect<FPP, GPU2>(size, csr_fac->getValuePtr());
perm_ids = new int[size]; perm_ids = new int[size];
...@@ -35,15 +56,16 @@ namespace Faust ...@@ -35,15 +56,16 @@ namespace Faust
void TransformHelperButterfly<FPP, GPU2>::multiply(const FPP* A, int A_ncols, FPP* C) void TransformHelperButterfly<FPP, GPU2>::multiply(const FPP* A, int A_ncols, FPP* C)
{ {
MatDense<FPP, GPU2> gpu_X(d_perm.size(), A_ncols, A); MatDense<FPP, GPU2> gpu_X(this->getNbRow(), A_ncols, A);
int i = 0;
if(has_permutation) if(has_permutation)
gpu_X.eltwise_mul(d_perm, perm_ids); gpu_X.eltwise_mul(d_perm, perm_ids);
for(auto gpu_bmat: opt_factors) for(auto gpu_bmat: opt_factors)
gpu_bmat.multiply(gpu_X); gpu_bmat.multiply(gpu_X);
gpu_X.tocpu(C); gpu_X.tocpu(C, nullptr);
} }
...@@ -51,7 +73,7 @@ namespace Faust ...@@ -51,7 +73,7 @@ namespace Faust
Vect<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const Vect<FPP, Cpu>& x) Vect<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const Vect<FPP, Cpu>& x)
{ {
Vect<FPP, Cpu> y; Vect<FPP, Cpu> y;
y.resize(d_perm.size()); y.resize(this->getNbRow());
multiply(x.getData(), y.getData()); multiply(x.getData(), y.getData());
return y; return y;
} }
...@@ -67,7 +89,7 @@ namespace Faust ...@@ -67,7 +89,7 @@ namespace Faust
Vect<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const FPP* x) Vect<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const FPP* x)
{ {
Vect<FPP, Cpu> y; Vect<FPP, Cpu> y;
y.resize(d_perm.size()); y.resize(this->getNbRow());
multiply(x, 1, y.getData()); multiply(x, 1, y.getData());
return y; return y;
} }
...@@ -77,7 +99,7 @@ namespace Faust ...@@ -77,7 +99,7 @@ namespace Faust
MatDense<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const MatDense<FPP,Cpu> &A) MatDense<FPP, Cpu> TransformHelperButterfly<FPP, GPU2>::multiply(const MatDense<FPP,Cpu> &A)
{ {
MatDense<FPP, Cpu> out; MatDense<FPP, Cpu> out;
out.resize(d_perm.size(), A.getNbCol()); out.resize(this->getNbRow(), A.getNbCol());
multiply(A.getData(), A.getNbCol(), out.getData()); multiply(A.getData(), A.getNbCol(), out.getData());
return out; return out;
} }
...@@ -85,7 +107,7 @@ namespace Faust ...@@ -85,7 +107,7 @@ namespace Faust
template<typename FPP> template<typename FPP>
MatDense<FPP, Cpu> TransformHelperButterfly<FPP,GPU2>::multiply(const MatSparse<FPP,Cpu> &X) MatDense<FPP, Cpu> TransformHelperButterfly<FPP,GPU2>::multiply(const MatSparse<FPP,Cpu> &X)
{ {
return multiply(MatDense<FPP, GPU2>(X)); return multiply(MatDense<FPP, Cpu>(X));
} }
template<typename FPP> template<typename FPP>
...@@ -112,9 +134,8 @@ namespace Faust ...@@ -112,9 +134,8 @@ namespace Faust
ButterflyMat<FPP, Cpu> cpu_bmat(factor, level); ButterflyMat<FPP, Cpu> cpu_bmat(factor, level);
auto cpu_d1 = cpu_bmat.getD1(); auto cpu_d1 = cpu_bmat.getD1();
auto cpu_d2 = cpu_bmat.getD2(); auto cpu_d2 = cpu_bmat.getD2();
d1 = Vect<FPP, GPU2>(cpu_d1.size(), cpu_d1.diagonal().data()); d1 = Vect<FPP, GPU2>(cpu_d1.rows(), cpu_d1.diagonal().data());
d2 = Vect<FPP, GPU2>(cpu_d2.size(), cpu_d2.diagonal().data()); d2 = Vect<FPP, GPU2>(cpu_d2.rows(), cpu_d2.diagonal().data());
auto sd_ids_vec = cpu_bmat.get_subdiag_ids(); auto sd_ids_vec = cpu_bmat.get_subdiag_ids();
subdiag_ids = new int[sd_ids_vec.size()]; subdiag_ids = new int[sd_ids_vec.size()];
memcpy(subdiag_ids, sd_ids_vec.data(), sizeof(int) * sd_ids_vec.size()); memcpy(subdiag_ids, sd_ids_vec.data(), sizeof(int) * sd_ids_vec.size());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment