Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 1bcaebec authored by hhakim's avatar hhakim
Browse files

Update HierarchicalFactFFT, Palm4MSAFFT and their tests.

- Adding a relevant test for HierarchicalFactFFT
- Replacing in HierarchicalFact the attribute object palm_global by an object pointer in order to implement polymorhism in HierarchicalFact (which uses Palm4MSA) and HierarchicalFactFFT (which uses Palm4MSAFFT). A destructor is here to delete the object from heap.
- Adding cmp test for each factor (of Faust approximation) in test_palm4MSAFFT.cpp.in (Relative diff. on Uhat each factor is not greater than 10^-4 comparatively to matlab results).
- Fixing mem. leak (data buffer) in Faust::Palm4MSAFFT::compute_D().
- Correcting error of D gradient computation in compute_D_grad_over_c().
- Implementing missing Faust::Palm4MSAFFT::compute_c() with empty_function because Palm4MSAFFT has a constant step size.
- Forcing ParamsPalmFFT.constant_step_size to be always true (as in matlab script).
parent 0b54ea70
Branches
Tags 2.3.1
No related merge requests found
Pipeline #833840 skipped
Showing
with 238 additions and 133 deletions
...@@ -14,44 +14,48 @@ ...@@ -14,44 +14,48 @@
#include <iostream> #include <iostream>
#include <iomanip> #include <iomanip>
using namespace Faust;
typedef @TEST_FPP@ FPP; typedef @TEST_FPP@ FPP;
typedef @TEST_FPP2@ FPP2; typedef @TEST_FPP2@ FPP2;
void doit(HierarchicalFact<FPP,Cpu,FPP2>* hierfact, int argc, FPP expectedLambda, FPP2 epsilon, Faust::MatDense<FPP, Cpu> & U, Faust::MatDense<FPP, Cpu> & Lap, vector<Faust::MatDense<FPP,Cpu>>& ref_facts);
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
vector<Faust::MatDense<FPP,Cpu>> ref_facts;
if (typeid(FPP) == typeid(double)) if (typeid(FPP) == typeid(double))
{ {
cout<<"floating point precision == double"<<endl; cout<<"floating point precision == double"<<endl;
} }
if (typeid(FPP) == typeid(float)) if (typeid(FPP) == typeid(float))
{ {
cout<<"floating point precision == float"<<endl; cout<<"floating point precision == float"<<endl;
} }
//TODO: this test is for now totally irelevant, we must get a Laplacian and a Fourier graph matrix to test properly HierarchicalFactFFT string configFilename = "@FAUST_DATA_MAT_DIR@/HierarchicalFactFFT_test_U_L_params.mat";
//default value string U_Filename = configFilename;
string configFilename = "@FAUST_CONFIG_MAT_DIR@/config_hierarchical_fact.mat"; string Lap_Filename = configFilename;
string MatrixFilename = "@FAUST_DATA_MAT_DIR@/matrix_hierarchical_fact.mat";
if (argc >= 3) if (argc >= 3)
{ {
MatrixFilename = argv[1]; U_Filename = argv[1];
configFilename = argv[2]; Lap_Filename = argv[2];
configFilename = argv[3];
} }
FPP expectedLambda = 0; FPP expectedLambda = 0;
if (argc >= 4) if (argc >= 5)
expectedLambda = atof(argv[3]); expectedLambda = atof(argv[3]);
FPP2 epsilon = 0.0001; FPP2 epsilon = 0.0001;
if (argc >= 5) if (argc >= 6)
epsilon = atof(argv[4]); epsilon = atof(argv[4]);
char transposedMatrix='N'; char transposedMatrix='N';
if (argc >= 6) if (argc >= 7)
transposedMatrix=argv[5][0]; transposedMatrix=argv[5][0];
...@@ -94,7 +98,6 @@ int main(int argc, char* argv[]) ...@@ -94,7 +98,6 @@ int main(int argc, char* argv[])
} }
// useless for CPU but use for compatibility with GPU // useless for CPU but use for compatibility with GPU
Faust::BlasHandle<Cpu> blasHandle; Faust::BlasHandle<Cpu> blasHandle;
Faust::SpBlasHandle<Cpu> spblasHandle; Faust::SpBlasHandle<Cpu> spblasHandle;
...@@ -102,34 +105,84 @@ int main(int argc, char* argv[]) ...@@ -102,34 +105,84 @@ int main(int argc, char* argv[])
// parameter setting // parameter setting
Faust::ParamsFFT<FPP,Cpu,FPP2> params; Faust::ParamsFFT<FPP,Cpu,FPP2> params;
init_params_from_matiofile<FPP,Cpu,FPP2>(params,configFilename.c_str(),"params"); init_params_from_matiofile<FPP,Cpu,FPP2>(params,configFilename.c_str(),"params");
// params.init_D = Faust::MatDense<FPP,Cpu>::eye(params.m_nbRow, params.m_nbCol);
// init_faust_mat_from_matio(params.init_D,U_Filename.c_str(),"init_D");
// cout << "init_lambda before overridding: " << params.init_lambda << endl;
// params.init_lambda = 1.0;
// params.isConstantStepSize = true;
params.isFactSideLeft = false; //false changes the accurracy to a lot better (but it's also very different from the matlab script's results)
init_faust_mat_from_matio(params.init_D,configFilename.c_str(),"init_D");
params.isVerbose = true;
params.Display(); params.Display();
Faust::MatDense<FPP,Cpu> tmp;
char mat_names[params.m_nbFact][7];
for(int i = 0; i < params.m_nbFact; i++)
{
sprintf(mat_names[i], "ref_f%d", i+1);
printf("%s\n", mat_names[i]);
}
for(int i = 0; i < params.m_nbFact; i++)
{
init_faust_mat_from_matio(tmp, configFilename.c_str(),mat_names[i]);
ref_facts.push_back(tmp);
}
cout << "norm init_D: " << params.init_D.norm() << endl;
cout << "init_D: " << endl; params.init_D.Display();
// matrix to be factorized // matrix to be factorized
Faust::MatDense<FPP,Cpu> matrix; Faust::MatDense<FPP,Cpu> U, Lap;
init_faust_mat_from_matio(matrix,MatrixFilename.c_str(),"matrix"); init_faust_mat_from_matio(U,U_Filename.c_str(),"U");
init_faust_mat_from_matio(Lap, Lap_Filename.c_str(), "Lap");
// transposed the matrix if needed // transposed the matrix if needed
if (transposedMatrix == 'T') if (transposedMatrix == 'T')
matrix.transpose(); U.transpose();
//algorithm //algorithm
Faust::HierarchicalFactFFT<FPP,Cpu,FPP2> hierFact(matrix,matrix,params,blasHandle,spblasHandle); Faust::HierarchicalFactFFT<FPP,Cpu,FPP2> hierFact(U,Lap,params,blasHandle,spblasHandle);
Faust::HierarchicalFact<FPP,Cpu,FPP2> hierFact_(U,params,blasHandle,spblasHandle);
doit(&hierFact, argc, expectedLambda, epsilon, U, Lap, ref_facts);
doit(&hierFact_, argc, expectedLambda, epsilon, U, Lap, ref_facts);
blasHandle.Destroy();
spblasHandle.Destroy();
return 0;
}
void doit(HierarchicalFact<FPP,Cpu,FPP2>* hierFact, int argc, FPP expectedLambda, FPP2 epsilon, Faust::MatDense<FPP, Cpu> & U, Faust::MatDense<FPP, Cpu> & Lap, vector<Faust::MatDense<FPP, Cpu>> & ref_facts)
{
Faust::Timer t1; Faust::Timer t1;
t1.start(); t1.start();
hierFact.compute_facts(); hierFact->compute_facts();
t1.stop(); t1.stop();
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
hierFact.print_timers(); hierFact->print_timers();
//hierFact.print_prox_timers(); //hierFact.print_prox_timers();
#endif #endif
cout <<"total hierarchical fact = "<<t1.get_time()<<endl; cout <<"total hierarchical fact = "<<t1.get_time()<<endl;
vector<Faust::MatSparse<FPP,Cpu> > facts; vector<Faust::MatSparse<FPP,Cpu> > facts;
hierFact.get_facts(facts); hierFact->get_facts(facts);
FPP lambda = hierFact.get_lambda(); Faust::MatDense<FPP,Cpu> tmp;
for(int i=0; i < ref_facts.size(); i++)
{
tmp = facts[i];
tmp -= ref_facts[i];
cout << "relerr fact " << i << ": " << tmp.norm()/ref_facts[i].norm() << endl;
}
FPP lambda = hierFact->get_lambda();
if (argc >= 3) if (argc >= 3)
{ {
if (Faust::fabs(lambda - expectedLambda) > epsilon) if (Faust::fabs(lambda - expectedLambda) > epsilon)
...@@ -140,7 +193,7 @@ int main(int argc, char* argv[]) ...@@ -140,7 +193,7 @@ int main(int argc, char* argv[])
} }
} }
(facts[0]) *= hierFact.get_lambda(); (facts[0]) *= hierFact->get_lambda();
// transform the sparse matrix into generic one // transform the sparse matrix into generic one
std::vector<Faust::MatGeneric<FPP,Cpu> *> list_fact_generic; std::vector<Faust::MatGeneric<FPP,Cpu> *> list_fact_generic;
list_fact_generic.resize(facts.size()); list_fact_generic.resize(facts.size());
...@@ -148,78 +201,82 @@ int main(int argc, char* argv[]) ...@@ -148,78 +201,82 @@ int main(int argc, char* argv[])
list_fact_generic[i]=facts[i].Clone(); list_fact_generic[i]=facts[i].Clone();
Faust::Transform<FPP,Cpu> hierFactCore(list_fact_generic); Faust::Transform<FPP,Cpu> hierFactCore(list_fact_generic);
cout << "Uhat Faust: " << endl;
hierFactCore.Display();
for (int i=0;i<list_fact_generic.size();i++) for (int i=0;i<list_fact_generic.size();i++)
delete list_fact_generic[i]; delete list_fact_generic[i];
char nomFichier[100];
string outputFile="@FAUST_BIN_TEST_OUTPUT_DIR@/hier_fact_factorisation.dat";
//WARNING no implemented
// hierFactCore.print_file(outputFile.c_str());
//write the given factorisation into a mat file
stringstream outputFilename;
outputFilename<<"@FAUST_BIN_TEST_OUTPUT_DIR@/"<<configFileBodyFile<<"_factorisation.mat";
std::cout<<"**************** WRITING FACTORISATION INTO ****************"<<std::endl;
std::cout<<"output filename : "<<outputFilename.str();
//WARNING no implemented
// hierFactCore.print_file(outputFile.c_str());
// modif NB : v1102 not implemented
//write_faust_core_into_matfile(hierFactCore,outputFilename.str().c_str(),"fact");
//relativeError //relativeError
Faust::MatDense<FPP,Cpu> faustProduct; Faust::MatDense<FPP,Cpu> faustProduct;
faustProduct=hierFactCore.get_product(); faustProduct=hierFactCore.get_product();
faustProduct-=matrix; faustProduct-=U;
FPP2 relativeError = Faust::fabs(faustProduct.norm()/matrix.norm()); FPP2 relativeError = Faust::fabs(faustProduct.norm()/U.norm());
std::cout<<std::endl; std::cout<<std::endl;
std::cout<<"**************** RELATIVE ERROR BETWEEN FAUST AND DATA MATRIX **************** "<<std::endl; std::cout<<"**************** RELATIVE ERROR BETWEEN FAUST AND Fourier MATRIX **************** "<<std::endl;
std::cout<<" "<<relativeError<<std::endl<<std::endl; std::cout<<" "<<relativeError<<std::endl<<std::endl;
HierarchicalFactFFT<FPP,Cpu,FPP2> *hierFactFFT;
if(hierFactFFT = dynamic_cast<HierarchicalFactFFT<FPP,Cpu, FPP2>*>(hierFact))
//time comparison between matrix vector product and faust-vector product
int niterTimeComp = 10;
if (niterTimeComp > 0)
{ {
//relativeError 2
Faust::MatDense<FPP,Cpu> lapProduct, lapErr;
const Faust::MatDense<FPP,Cpu>& D = hierFactFFT->get_D();
lapErr = hierFactCore.get_product();
lapProduct = hierFactCore.get_product();
lapProduct.transpose();
// lapErr = Uhat*D*Uhat'
lapErr.multiplyRight(D);
lapErr.multiplyRight(lapProduct);
cout << "Lap norm: " << Lap.norm() << endl;
cout << "norm Uhat*D*Uhat':" << lapErr.norm() << endl;
// lapErr = Uhat*D*Uhat'-Lap
lapErr-=Lap;
FPP2 relativeError2 = Faust::fabs(lapErr.norm()/Lap.norm());
Faust::Timer tdense; std::cout<<std::endl;
Faust::Timer tfaust; std::cout<<"**************** RELATIVE ERROR BETWEEN FAUST*D*FAUST' AND Lap MATRIX **************** "<<std::endl;
Faust::Vect<FPP,Cpu> x(matrix.getNbCol()); std::cout<<" "<<relativeError2<<std::endl<<std::endl;
Faust::Vect<FPP,Cpu> ydense(matrix.getNbRow());
Faust::Vect<FPP,Cpu> yfaust(hierFactCore.getNbRow());
for (int i=0;i<niterTimeComp;i++)
{
//random initilisation of vector x
for (int j=0;j<x.size();j++)
{
x[j]=std::rand()*2.0/RAND_MAX-1.0;
}
tdense.start(); cout<< " D info:" << endl;
ydense = matrix * x; cout << " D fro. norm: " << D.norm() << endl;
tdense.stop(); cout << " D nnz: " << D.getNonZeros() << endl;
tfaust.start(); //time comparison between matrix vector product and faust-vector product
yfaust = hierFactCore * x; int niterTimeComp = 10;
tfaust.stop(); if (niterTimeComp > 0)
{
} Faust::Timer tdense;
std::cout<<std::endl; Faust::Timer tfaust;
Faust::Vect<FPP,Cpu> x(U.getNbCol());
Faust::Vect<FPP,Cpu> ydense(U.getNbRow());
Faust::Vect<FPP,Cpu> yfaust(hierFactCore.getNbRow());
for (int i=0;i<niterTimeComp;i++)
{
//random initilisation of vector x
for (int j=0;j<x.size();j++)
{
x[j]=std::rand()*2.0/RAND_MAX-1.0;
}
std::cout<<"**************** TIME COMPARISON MATRIX VECTOR PRODUCT **************** "<<std::endl; tdense.start();
std::cout<<" TIME SPEED-UP : "<<tdense.get_time()/tfaust.get_time()<<std::endl; ydense = U * x;
std::cout<<" MEAN TIME dense : "<<tdense.get_time()/((float) niterTimeComp)<<std::endl; tdense.stop();
std::cout<<" MEAN TIME faust : "<<tfaust.get_time()/((float) niterTimeComp)<<std::endl;
cout<<"lambda="<<std::setprecision(20)<<hierFact.get_lambda()<<endl;
}
blasHandle.Destroy(); tfaust.start();
spblasHandle.Destroy(); yfaust = hierFactCore * x;
tfaust.stop();
}
std::cout<<std::endl;
return 0; std::cout<<"**************** TIME COMPARISON MATRIX VECTOR PRODUCT **************** "<<std::endl;
std::cout<<" TIME SPEED-UP : "<<tdense.get_time()/tfaust.get_time()<<std::endl;
std::cout<<" MEAN TIME dense : "<<tdense.get_time()/((float) niterTimeComp)<<std::endl;
std::cout<<" MEAN TIME faust : "<<tfaust.get_time()/((float) niterTimeComp)<<std::endl;
cout<<"lambda="<<std::setprecision(20)<<hierFactFFT->get_lambda()<<endl;
}
}
} }
/****************************************************************************/ /*****************************************************************************//*{{{*/
/* Description: */ /* Description: */
/* For more information on the FAuST Project, please visit the website */ /* For more information on the FAuST Project, please visit the website */
/* of the project : <http://faust.inria.fr> */ /* of the project : <http://faust.inria.fr> */
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
/* approximations of matrices and applications", Journal of Selected */ /* approximations of matrices and applications", Journal of Selected */
/* Topics in Signal Processing, 2016. */ /* Topics in Signal Processing, 2016. */
/* <https://hal.archives-ouvertes.fr/hal-01167948v1> */ /* <https://hal.archives-ouvertes.fr/hal-01167948v1> */
/****************************************************************************/ /****************************************************************************//*}}}*//*}}}*//*}}}*/
#include "faust_MatDense.h" #include "faust_MatDense.h"
#include "faust_Params.h" #include "faust_Params.h"
#include "faust_ParamsPalmFFT.h" #include "faust_ParamsPalmFFT.h"
...@@ -60,7 +60,8 @@ typedef @TEST_FPP2@ FPP2; ...@@ -60,7 +60,8 @@ typedef @TEST_FPP2@ FPP2;
*/ */
int main() int main()
{ {
Faust::MatDense<FPP,Cpu> data, init_facts1, init_facts2, init_D; Faust::MatDense<FPP,Cpu> data, init_facts1, init_facts2, init_D, ref_Uhat1, ref_Uhat2, err_Uhat1, err_Uhat2;
FPP tmp;
string input_path = "@FAUST_DATA_MAT_DIR@/ref_test_PALM4SMA_FFT2.mat"; string input_path = "@FAUST_DATA_MAT_DIR@/ref_test_PALM4SMA_FFT2.mat";
...@@ -70,6 +71,8 @@ int main() ...@@ -70,6 +71,8 @@ int main()
// data is a Laplacian (symmetric matrix) // data is a Laplacian (symmetric matrix)
init_faust_mat_from_matio(data, configPalm2Filename, "data"); init_faust_mat_from_matio(data, configPalm2Filename, "data");
init_faust_mat_from_matio(ref_Uhat1, configPalm2Filename, "Uhat1");
init_faust_mat_from_matio(ref_Uhat2, configPalm2Filename, "Uhat2");
init_faust_mat_from_matio(init_facts1, configPalm2Filename, "p_init_facts1"); init_faust_mat_from_matio(init_facts1, configPalm2Filename, "p_init_facts1");
init_faust_mat_from_matio(init_facts2, configPalm2Filename, "p_init_facts2"); init_faust_mat_from_matio(init_facts2, configPalm2Filename, "p_init_facts2");
...@@ -117,7 +120,8 @@ int main() ...@@ -117,7 +120,8 @@ int main()
Faust::StoppingCriterion<FPP2> crit(niter); Faust::StoppingCriterion<FPP2> crit(niter);
Faust::ParamsPalmFFT<FPP,Cpu,FPP2> params(data, nfacts, cons, initFact, init_D, crit, verbose, updateWay, initLambda, true, step_size); Faust::ParamsPalmFFT<FPP,Cpu,FPP2> params(data, nfacts, cons, initFact, init_D, crit, verbose, updateWay, initLambda, step_size);
params.isVerbose = true;
#ifdef DEBUG #ifdef DEBUG
params.Display(); params.Display();
...@@ -138,9 +142,24 @@ int main() ...@@ -138,9 +142,24 @@ int main()
// palm2.next_step(); // palm2.next_step();
palm2.compute_facts(); palm2.compute_facts();
std::vector<Faust::MatDense<FPP,Cpu> >& full_facts = const_cast< std::vector<Faust::MatDense<FPP,Cpu> >&>(palm2.get_facts()); std::vector<Faust::MatDense<FPP,Cpu> >& full_facts = const_cast< std::vector<Faust::MatDense<FPP,Cpu> >&>(palm2.get_facts());
FPP lambda = palm2.get_lambda(); FPP lambda = palm2.get_lambda();
(full_facts[0]) *= lambda; (full_facts[0]) *= lambda;
err_Uhat1 = ref_Uhat1;
err_Uhat1 -= full_facts[0];
tmp = ref_Uhat1.norm();
cout << "RE for Uhat1: " << err_Uhat1.norm()/tmp << endl;
err_Uhat2 = ref_Uhat2;
err_Uhat2 -= full_facts[1];
tmp = ref_Uhat2.norm();
cout << "RE for Uhat2: " << err_Uhat2.norm()/tmp << endl;
Faust::Transform<FPP, Cpu>* t = new Faust::Transform<FPP, Cpu>(full_facts); Faust::Transform<FPP, Cpu>* t = new Faust::Transform<FPP, Cpu>(full_facts);
//relativeError //relativeError
...@@ -165,7 +184,7 @@ int main() ...@@ -165,7 +184,7 @@ int main()
double relativeError = Faust::fabs(mat.norm()/data.norm()); double relativeError = Faust::fabs(mat.norm()/data.norm());
std::cout<<std::endl; std::cout<<std::endl;
std::cout<<"**************** RELATIVE ERROR BETWEEN FAUST*D*Faust.transpose AND DATA LAPLACIAN MATRIX **************** "<<std::endl; std::cout<<"**************** RELATIVE ERROR BETWEEN FAUST_UHAT*D*Faust_UHAT.transpose AND DATA LAPLACIAN MATRIX **************** "<<std::endl;
std::cout<< "\t\t" << relativeError<<std::endl<<std::endl; std::cout<< "\t\t" << relativeError<<std::endl<<std::endl;
......
...@@ -7,7 +7,6 @@ import sys ...@@ -7,7 +7,6 @@ import sys
import numpy as np import numpy as np
from scipy.io import savemat,loadmat from scipy.io import savemat,loadmat
from numpy.linalg import norm from numpy.linalg import norm
from pyfaust import Faust
import math import math
class TestFaustPy(unittest.TestCase): class TestFaustPy(unittest.TestCase):
...@@ -898,6 +897,7 @@ if __name__ == "__main__": ...@@ -898,6 +897,7 @@ if __name__ == "__main__":
# (to find pyfaust module) # (to find pyfaust module)
sys.path.append(sys.argv[1]) sys.path.append(sys.argv[1])
del sys.argv[1] # deleted to avoid interfering with unittest del sys.argv[1] # deleted to avoid interfering with unittest
from pyfaust import Faust
if(len(sys.argv) > 1): if(len(sys.argv) > 1):
#ENOTE: test only a single test if name passed on command line #ENOTE: test only a single test if name passed on command line
singleton = unittest.TestSuite() singleton = unittest.TestSuite()
......
...@@ -79,11 +79,11 @@ namespace Faust ...@@ -79,11 +79,11 @@ namespace Faust
HierarchicalFact(const Faust::MatDense<FPP,DEVICE>& M, const Faust::Params<FPP,DEVICE,FPP2>& params_, Faust::BlasHandle<DEVICE> cublasHandle, SpBlasHandle<DEVICE> cusparseHandle); HierarchicalFact(const Faust::MatDense<FPP,DEVICE>& M, const Faust::Params<FPP,DEVICE,FPP2>& params_, Faust::BlasHandle<DEVICE> cublasHandle, SpBlasHandle<DEVICE> cusparseHandle);
void get_facts(Faust::Transform<FPP,DEVICE> &)const; void get_facts(Faust::Transform<FPP,DEVICE> &)const;
void get_facts(std::vector<Faust::MatSparse<FPP,DEVICE> >&)const; void get_facts(std::vector<Faust::MatSparse<FPP,DEVICE> >&)const;
void get_facts(std::vector<Faust::MatDense<FPP,DEVICE> >& fact)const{fact = palm_global.get_facts();} void get_facts(std::vector<Faust::MatDense<FPP,DEVICE> >& fact)const{fact = palm_global->get_facts();}
void compute_facts(); void compute_facts();
FPP get_lambda()const{return palm_global.get_lambda();} FPP get_lambda()const{return palm_global->get_lambda();}
const std::vector<std::vector< FPP> >& get_errors()const; const std::vector<std::vector< FPP> >& get_errors()const;
virtual ~HierarchicalFact();
private: private:
void init(); void init();
...@@ -101,7 +101,7 @@ namespace Faust ...@@ -101,7 +101,7 @@ namespace Faust
int m_indFact ; //indice de factorisation (!= Faust::Palm4MSA::m_indFact : indice de facteur) int m_indFact ; //indice de factorisation (!= Faust::Palm4MSA::m_indFact : indice de facteur)
int nbFact; // nombre de factorisations (!= Faust::Palm4MSA::nbFact : nombre de facteurs) int nbFact; // nombre de factorisations (!= Faust::Palm4MSA::nbFact : nombre de facteurs)
Faust::Palm4MSA<FPP,DEVICE,FPP2> palm_2; Faust::Palm4MSA<FPP,DEVICE,FPP2> palm_2;
Faust::Palm4MSA<FPP,DEVICE,FPP2> palm_global; Faust::Palm4MSA<FPP,DEVICE,FPP2>* palm_global;
const FPP default_lambda; // initial value of lambda for factorization into two factors const FPP default_lambda; // initial value of lambda for factorization into two factors
//std::vector<Faust::MatDense<FPP,DEVICE> > S; //std::vector<Faust::MatDense<FPP,DEVICE> > S;
std::vector<const Faust::ConstraintGeneric*> cons_tmp_global; std::vector<const Faust::ConstraintGeneric*> cons_tmp_global;
......
...@@ -71,7 +71,7 @@ Faust::HierarchicalFact<FPP,DEVICE,FPP2>::HierarchicalFact(const Faust::MatDense ...@@ -71,7 +71,7 @@ Faust::HierarchicalFact<FPP,DEVICE,FPP2>::HierarchicalFact(const Faust::MatDense
m_isVerbose(params_.isVerbose), m_isVerbose(params_.isVerbose),
nbFact(params_.m_nbFact-1), nbFact(params_.m_nbFact-1),
palm_2(Palm4MSA<FPP,DEVICE,FPP2>(M,params_, cublasHandle, false)), palm_2(Palm4MSA<FPP,DEVICE,FPP2>(M,params_, cublasHandle, false)),
palm_global(Palm4MSA<FPP,DEVICE,FPP2>(M,params_, cublasHandle, true)), palm_global(new Palm4MSA<FPP,DEVICE,FPP2>(M,params_, cublasHandle, true)),
cons_tmp_global(vector<const Faust::ConstraintGeneric*>()), cons_tmp_global(vector<const Faust::ConstraintGeneric*>()),
default_lambda(params_.init_lambda), default_lambda(params_.init_lambda),
isFactorizationComputed(false), isFactorizationComputed(false),
...@@ -101,8 +101,8 @@ t_init.start(); ...@@ -101,8 +101,8 @@ t_init.start();
cons_tmp_global.push_back(cons[1][m_indFact]); cons_tmp_global.push_back(cons[1][m_indFact]);
palm_global.set_constraint(cons_tmp_global); palm_global->set_constraint(cons_tmp_global);
palm_global.init_fact(1); palm_global->init_fact(1);
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
...@@ -182,7 +182,7 @@ palm_2.init_local_timers(); ...@@ -182,7 +182,7 @@ palm_2.init_local_timers();
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
palm_2.print_local_timers(); palm_2.print_local_timers();
#endif #endif
palm_global.update_lambda_from_palm(palm_2); palm_global->update_lambda_from_palm(palm_2);
if (m_isFactSideLeft) if (m_isFactSideLeft)
...@@ -200,22 +200,22 @@ palm_2.print_local_timers(); ...@@ -200,22 +200,22 @@ palm_2.print_local_timers();
cons_tmp_global[m_indFact+1]=cons[1][m_indFact]; cons_tmp_global[m_indFact+1]=cons[1][m_indFact];
} }
palm_global.set_constraint(cons_tmp_global); palm_global->set_constraint(cons_tmp_global);
palm_global.init_fact_from_palm(palm_2, m_isFactSideLeft); palm_global->init_fact_from_palm(palm_2, m_isFactSideLeft);
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
palm_global.init_local_timers(); palm_global->init_local_timers();
#endif #endif
//while(palm_global.do_continue()) //while(palm_global->do_continue())
// palm_global.next_step(); // palm_global->next_step();
palm_global.compute_facts(); this->palm_global->compute_facts();
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
palm_global.print_local_timers(); palm_global->print_local_timers();
#endif #endif
palm_2.set_data(palm_global.get_res(m_isFactSideLeft, m_indFact)); palm_2.set_data(palm_global->get_res(m_isFactSideLeft, m_indFact));
compute_errors(); compute_errors();
...@@ -248,7 +248,7 @@ void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::get_facts(std::vector<Faust::MatS ...@@ -248,7 +248,7 @@ void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::get_facts(std::vector<Faust::MatS
exit(EXIT_FAILURE); exit(EXIT_FAILURE);
}*/ }*/
const std::vector<Faust::MatDense<FPP,DEVICE> >& full_facts = palm_global.get_facts(); const std::vector<Faust::MatDense<FPP,DEVICE> >& full_facts = palm_global->get_facts();
sparse_facts.resize(full_facts.size()); sparse_facts.resize(full_facts.size());
for (int i=0 ; i<sparse_facts.size() ; i++) for (int i=0 ; i<sparse_facts.size() ; i++)
{ {
...@@ -307,7 +307,7 @@ void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::compute_errors() ...@@ -307,7 +307,7 @@ void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::compute_errors()
delete Transform_facts[i]; delete Transform_facts[i];
const Faust::MatDense<FPP,DEVICE> estimate_mat = faust_Transform_tmp.get_product(cublas_handle, cusparse_handle); const Faust::MatDense<FPP,DEVICE> estimate_mat = faust_Transform_tmp.get_product(cublas_handle, cusparse_handle);
Faust::MatDense<FPP,DEVICE> data(palm_global.get_data()); Faust::MatDense<FPP,DEVICE> data(palm_global->get_data());
FPP2 data_norm = Faust::fabs(data.norm()); FPP2 data_norm = Faust::fabs(data.norm());
...@@ -327,7 +327,7 @@ template<typename FPP,Device DEVICE,typename FPP2> Faust::Timer Faust::Hierarchi ...@@ -327,7 +327,7 @@ template<typename FPP,Device DEVICE,typename FPP2> Faust::Timer Faust::Hierarchi
template<typename FPP,Device DEVICE,typename FPP2> template<typename FPP,Device DEVICE,typename FPP2>
void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::print_timers()const void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::print_timers()const
{ {
palm_global.print_global_timers(); palm_global->print_global_timers();
cout << "timers in Faust::HierarchicalFact :" << endl; cout << "timers in Faust::HierarchicalFact :" << endl;
cout << "t_init = " << t_init.get_time() << " s for "<< t_init.get_nb_call() << " calls" << endl; cout << "t_init = " << t_init.get_time() << " s for "<< t_init.get_nb_call() << " calls" << endl;
cout << "t_next_step = " << t_next_step.get_time() << " s for "<< t_next_step.get_nb_call() << " calls" << endl<<endl; cout << "t_next_step = " << t_next_step.get_time() << " s for "<< t_next_step.get_nb_call() << " calls" << endl<<endl;
...@@ -336,5 +336,9 @@ void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::print_timers()const ...@@ -336,5 +336,9 @@ void Faust::HierarchicalFact<FPP,DEVICE,FPP2>::print_timers()const
} }
#endif #endif
template<typename FPP,Device DEVICE,typename FPP2>
Faust::HierarchicalFact<FPP,DEVICE,FPP2>::~HierarchicalFact()
{
delete this->palm_global;
}
#endif #endif
...@@ -12,18 +12,32 @@ namespace Faust ...@@ -12,18 +12,32 @@ namespace Faust
class HierarchicalFactFFT : public HierarchicalFact<FPP, DEVICE, FPP2> class HierarchicalFactFFT : public HierarchicalFact<FPP, DEVICE, FPP2>
{ {
Palm4MSAFFT<FPP,DEVICE,FPP2> palm_global;
static const char * m_className; static const char * m_className;
public: public:
//TODO: move def. code in .hpp //TODO: move def. code in .hpp
HierarchicalFactFFT(const MatDense<FPP,DEVICE>& U, const MatDense<FPP,DEVICE>& Lap,const ParamsFFT<FPP,DEVICE,FPP2>& params, BlasHandle<DEVICE> cublasHandle, SpBlasHandle<DEVICE> cusparseHandle): HierarchicalFact<FPP, DEVICE, FPP2>(U, params, cublasHandle, cusparseHandle), palm_global(Palm4MSAFFT<FPP,DEVICE,FPP2>(Lap, params, cublasHandle, true)) HierarchicalFactFFT(const MatDense<FPP,DEVICE>& U, const MatDense<FPP,DEVICE>& Lap, ParamsFFT<FPP,DEVICE,FPP2>& params, BlasHandle<DEVICE> cublasHandle, SpBlasHandle<DEVICE> cusparseHandle): HierarchicalFact<FPP, DEVICE, FPP2>(U, params, cublasHandle, cusparseHandle)
//TODO: verify if palm_global is really initialized after parent ctor call //TODO: verify if palm_global is really initialized after parent ctor call
{ {
if ((U.getNbRow() != params.m_nbRow) | (U.getNbCol() != params.m_nbCol)) if ((U.getNbRow() != params.m_nbRow) | (U.getNbCol() != params.m_nbCol))
handleError(m_className,"constructor : params and Fourier matrix U haven't compatible size"); handleError(m_className,"constructor : params and Fourier matrix U haven't compatible size");
if((Lap.getNbRow() != params.m_nbRow) | (Lap.getNbCol() != params.m_nbCol)) if((Lap.getNbRow() != params.m_nbRow) | (Lap.getNbCol() != params.m_nbCol))
handleError(m_className,"constructor : params and Laplacian matrix Lap haven't compatible size"); handleError(m_className,"constructor : params and Laplacian matrix Lap haven't compatible size");
delete this->palm_global;
cout << "HierarchicalFactFFT init_lambda:" << params.init_lambda << endl;
this->palm_global = new Palm4MSAFFT<FPP,DEVICE,FPP2>(Lap, params, cublasHandle, true);
}
const MatDense<FPP, DEVICE>& get_D()
{
return dynamic_cast<Palm4MSAFFT<FPP,Cpu,FPP2>*>(this->palm_global)->get_D();
}
void next_step()
{
this->palm_2.m_lambda = FPP(1.);
Faust::HierarchicalFact<FPP, DEVICE, FPP2>::next_step();
} }
......
...@@ -57,6 +57,7 @@ namespace Faust ...@@ -57,6 +57,7 @@ namespace Faust
template<typename FPP,Device DEVICE> class MatDense; template<typename FPP,Device DEVICE> class MatDense;
template<typename FPP,Device DEVICE> class Transform; template<typename FPP,Device DEVICE> class Transform;
template<typename FPP, Device DEVICE, typename FPP2> class HierarchicalFactFFT;
class ConstraintGeneric; class ConstraintGeneric;
template<typename FPP,Device DEVICE, typename FPP2> class Params; template<typename FPP,Device DEVICE, typename FPP2> class Params;
...@@ -76,7 +77,7 @@ namespace Faust ...@@ -76,7 +77,7 @@ namespace Faust
template<typename FPP,Device DEVICE,typename FPP2 = double> template<typename FPP,Device DEVICE,typename FPP2 = double>
class Palm4MSA class Palm4MSA
{ {
// friend class Faust::HierarchicalFactFFT<FPP,DEVICE, FPP2>;
public: public:
/*! /*!
...@@ -132,11 +133,11 @@ namespace Faust ...@@ -132,11 +133,11 @@ namespace Faust
*/ */
void init_fact_from_palm(const Palm4MSA& palm, bool isFactSideLeft); void init_fact_from_palm(const Palm4MSA& palm, bool isFactSideLeft);
const std::vector<Faust::MatDense<FPP,DEVICE> >& get_facts()const {return S;} const std::vector<Faust::MatDense<FPP,DEVICE> >& get_facts()const {return S;}
~Palm4MSA(){} virtual ~Palm4MSA(){}
protected: protected:
void check_constraint_validity(); void check_constraint_validity();
void compute_c(); virtual void compute_c();
virtual void compute_grad_over_c(); virtual void compute_grad_over_c();
void compute_projection(); void compute_projection();
void update_L(); void update_L();
......
...@@ -530,7 +530,6 @@ void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_c() ...@@ -530,7 +530,6 @@ void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_c()
t_local_compute_c.start(); t_local_compute_c.start();
#endif #endif
if (!isConstantStepSize) if (!isConstantStepSize)
{ {
faust_int flag1,flag2; faust_int flag1,flag2;
......
...@@ -18,13 +18,14 @@ namespace Faust { ...@@ -18,13 +18,14 @@ namespace Faust {
//TODO: another ctor (like in Palm4MSA) for hierarchical algo. use //TODO: another ctor (like in Palm4MSA) for hierarchical algo. use
Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>& params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal=false); Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>& params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal=false);
Palm4MSAFFT(const MatDense<FPP,DEVICE>& Lap, const ParamsFFT<FPP,DEVICE,FPP2> & params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal); Palm4MSAFFT(const MatDense<FPP,DEVICE>& Lap, const ParamsFFT<FPP,DEVICE,FPP2> & params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal);
virtual void next_step(); void next_step();
const MatDense<FPP, DEVICE>& get_D(); const MatDense<FPP, DEVICE>& get_D();
private: private:
virtual void compute_grad_over_c(); void compute_grad_over_c();
virtual void compute_lambda(); void compute_lambda();
void compute_D(); void compute_D();
void compute_D_grad_over_c(); void compute_D_grad_over_c();
void compute_c();
}; };
#include "faust_Palm4MSAFFT.hpp" #include "faust_Palm4MSAFFT.hpp"
......
...@@ -9,7 +9,6 @@ Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2> ...@@ -9,7 +9,6 @@ Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>
template<typename FPP,Device DEVICE,typename FPP2> template<typename FPP,Device DEVICE,typename FPP2>
Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const MatDense<FPP,DEVICE>& Lap, const ParamsFFT<FPP,DEVICE,FPP2> & params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal) : Palm4MSA<FPP,DEVICE,FPP2>(Lap, params, blasHandle, isGlobal), D(params.init_D) Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const MatDense<FPP,DEVICE>& Lap, const ParamsFFT<FPP,DEVICE,FPP2> & params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal) : Palm4MSA<FPP,DEVICE,FPP2>(Lap, params, blasHandle, isGlobal), D(params.init_D)
{ {
} }
template <typename FPP, Device DEVICE, typename FPP2> template <typename FPP, Device DEVICE, typename FPP2>
...@@ -95,7 +94,6 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c() ...@@ -95,7 +94,6 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
// this->error = lambda*tmp1*lambda*tmp2'-data // this->error is data before this call // this->error = lambda*tmp1*lambda*tmp2'-data // this->error is data before this call
gemm(tmp1, tmp2, this->error, this->m_lambda*this->m_lambda, (FPP)-1.0, 'N', this->TorH, this->blas_handle); gemm(tmp1, tmp2, this->error, this->m_lambda*this->m_lambda, (FPP)-1.0, 'N', this->TorH, this->blas_handle);
if (idx==0 || idx==2) // computing L'*this->error first, then (L'*this->error)*R' if (idx==0 || idx==2) // computing L'*this->error first, then (L'*this->error)*R'
{ {
if (!this->isUpdateWayR2L) if (!this->isUpdateWayR2L)
...@@ -113,7 +111,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c() ...@@ -113,7 +111,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
gemm(tmp1, this->LorR, tmp2, this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle); gemm(tmp1, this->LorR, tmp2, this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
} }
// grad_over_c = 1/this->c*tmp3*tmp2 // grad_over_c = 1/this->c*tmp3*tmp2
gemm(tmp3, tmp2, this->grad_over_c, (FPP) 1.0/this->c, (FPP) (FPP) 0.0,'N','N', this->blas_handle); gemm(tmp3, tmp2, this->grad_over_c, (FPP) 1.0/this->c, (FPP) 0.0,'N','N', this->blas_handle);
} }
else // computing this->error*R' first, then L'*(this->error*lambda*LSRD*R') else // computing this->error*R' first, then L'*(this->error*lambda*LSRD*R')
...@@ -139,7 +137,6 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c() ...@@ -139,7 +137,6 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
} }
this->isGradComputed = true; this->isGradComputed = true;
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
...@@ -152,7 +149,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c() ...@@ -152,7 +149,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
template <typename FPP, Device DEVICE, typename FPP2> template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda() void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda()
{ {
//TODO: override parent's method // override parent's method
// Xhat = (S[0]*...*S[nfact-1])*D*(S[0]*...*S[nfact-1])' // Xhat = (S[0]*...*S[nfact-1])*D*(S[0]*...*S[nfact-1])'
// Xhat = LorR*D*LorR' // LorR equals the prod of all factors after their update iterations (in loop of next_step()) // Xhat = LorR*D*LorR' // LorR equals the prod of all factors after their update iterations (in loop of next_step())
MatDense<FPP,Cpu> tmp; MatDense<FPP,Cpu> tmp;
...@@ -170,7 +167,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda() ...@@ -170,7 +167,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda()
// reset LorR at the factor product to continue next iterations // reset LorR at the factor product to continue next iterations
this->LorR = tmp; this->LorR = tmp;
//then we finish the lambda computation with a sqrt() (Fro. norm) //then we finish the lambda computation with a sqrt() (Fro. norm)
this->m_lambda = std::sqrt(this->m_lambda); this->m_lambda = std::sqrt(/*Faust::abs(*/this->m_lambda);//);
// (that's an additional operation in Palm4MSAFFT) // (that's an additional operation in Palm4MSAFFT)
} }
...@@ -183,13 +180,15 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::next_step() ...@@ -183,13 +180,15 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::next_step()
this->compute_D(); this->compute_D();
} }
template <typename FPP, Device DEVICE, typename FPP2> template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D() void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D()
{ {
// besides to what the parent has done // besides to what the parent has done
// we need to update D // we need to update D
compute_D_grad_over_c(); compute_D_grad_over_c();
D_grad_over_c.scalarMultiply(this->m_lambda/this->c);
D -= D_grad_over_c; D -= D_grad_over_c;
//TODO: optimize MatSparse + no-copy (Eigen::DiagonalMatrix ?) //TODO: optimize MatSparse + no-copy (Eigen::DiagonalMatrix ?)
FPP * data = new FPP[D.getNbRow()*D.getNbCol()]; FPP * data = new FPP[D.getNbRow()*D.getNbCol()];
...@@ -197,20 +196,23 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D() ...@@ -197,20 +196,23 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D()
for(faust_unsigned_int i = 0; i < D.getNbCol();i++) for(faust_unsigned_int i = 0; i < D.getNbCol();i++)
data[i*D.getNbCol()+i] = D[i*D.getNbCol()+i]; data[i*D.getNbCol()+i] = D[i*D.getNbCol()+i];
D = MatDense<FPP,Cpu>(data, D.getNbRow(), D.getNbCol()); D = MatDense<FPP,Cpu>(data, D.getNbRow(), D.getNbCol());
delete data;
} }
template <typename FPP, Device DEVICE, typename FPP2> template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D_grad_over_c() void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D_grad_over_c()
{ {
// grad = 0.5*LorR'*(LorR*D*LorR' - X)*LorR // Uhat = lambda*LorR
// grad = 0.5*Uhat'*(Uhat*D*Uhat' - X)*Uhat
MatDense<FPP, Cpu> tmp; MatDense<FPP, Cpu> tmp;
//compute_lambda has already compute D_grad_over_c = LorR*D*LorR' //compute_lambda has already compute D_grad_over_c = LorR*D*LorR'
D_grad_over_c.scalarMultiply(this->m_lambda*this->m_lambda);
D_grad_over_c -= this->data; D_grad_over_c -= this->data;
//TODO: opt. by determining best order of product //TODO: opt. by determining best order of product
// tmp = LorR'*(LorR*D*LorR' - X) // tmp = Uhat'*(Uhat*D*Uhat' - X)
gemm(this->LorR, D_grad_over_c, tmp, (FPP) 1., (FPP) 0., this->TorH, 'N', this->blas_handle); gemm(this->LorR, D_grad_over_c, tmp, (FPP) this->m_lambda, (FPP) 0., this->TorH, 'N', this->blas_handle);
// D_grad_over_c = LorR'*(LorR*D*LorR' - X)*LorR // D_grad_over_c = 0.5*Uhat'*(Uhat*D*Uhat' - X)*Uhat
gemm(tmp, this->LorR, D_grad_over_c, (FPP) 1., (FPP) 0., 'N', 'N', this->blas_handle); gemm(tmp, this->LorR, D_grad_over_c, (FPP) .5*this->m_lambda/this->c, (FPP) 0., 'N', 'N', this->blas_handle);
} }
template <typename FPP, Device DEVICE, typename FPP2> template <typename FPP, Device DEVICE, typename FPP2>
...@@ -219,3 +221,11 @@ const MatDense<FPP, DEVICE>& Palm4MSAFFT<FPP,DEVICE,FPP2>::get_D() ...@@ -219,3 +221,11 @@ const MatDense<FPP, DEVICE>& Palm4MSAFFT<FPP,DEVICE,FPP2>::get_D()
return this->D; return this->D;
} }
template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_c()
{
//do nothing because the Palm4MSAFFT has always a constant step size
this->isCComputed = true;
}
...@@ -30,7 +30,7 @@ namespace Faust ...@@ -30,7 +30,7 @@ namespace Faust
const bool isFactSideLeft = Params<FPP,DEVICE,FPP2>::defaultFactSideLeft, const bool isFactSideLeft = Params<FPP,DEVICE,FPP2>::defaultFactSideLeft,
const FPP init_lambda = Params<FPP,DEVICE,FPP2>::defaultLambda, const FPP init_lambda = Params<FPP,DEVICE,FPP2>::defaultLambda,
const bool constant_step_size = Params<FPP,DEVICE,FPP2>::defaultConstantStepSize, const bool constant_step_size = Params<FPP,DEVICE,FPP2>::defaultConstantStepSize,
const FPP step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize): Params<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, stop_crit_2facts, stop_crit_global, isVerbose, isFactSideLeft, init_lambda, constant_step_size, step_size), init_D(init_D) const FPP step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize): Params<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, stop_crit_2facts, stop_crit_global, isVerbose, isFactSideLeft, init_lambda, true, step_size), init_D(init_D)
{ {
} }
......
...@@ -12,6 +12,7 @@ namespace Faust ...@@ -12,6 +12,7 @@ namespace Faust
template<typename FPP, Device DEVICE, typename FPP2 = double> template<typename FPP, Device DEVICE, typename FPP2 = double>
class ParamsPalmFFT : public Faust::ParamsPalm<FPP,DEVICE,FPP2> class ParamsPalmFFT : public Faust::ParamsPalm<FPP,DEVICE,FPP2>
{ {
public: public:
//ctor definitions in header because it consists mainly to call parent ctor //ctor definitions in header because it consists mainly to call parent ctor
...@@ -24,8 +25,7 @@ namespace Faust ...@@ -24,8 +25,7 @@ namespace Faust
const bool isVerbose_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultVerbosity , const bool isVerbose_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultVerbosity ,
const bool isUpdateWayR2L_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultUpdateWayR2L , const bool isUpdateWayR2L_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultUpdateWayR2L ,
const FPP init_lambda_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultLambda, const FPP init_lambda_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultLambda,
const bool constant_step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultConstantStepSize, const FPP step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, true /*constant_step_size is always true for Palm4MSAFFT */, step_size_), init_D(init_D) {}
const FPP step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, constant_step_size_, step_size_), init_D(init_D) {}
ParamsPalmFFT() : ParamsPalm<FPP,DEVICE,FPP2>(), init_D(0,0) {} ParamsPalmFFT() : ParamsPalm<FPP,DEVICE,FPP2>(), init_D(0,0) {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment