Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 7f9c00e0 authored by hhakim's avatar hhakim
Browse files

Implement Palm4MSAFFT::compute_grad_over_c().

Fully implement Palm4MSAFFT (awaiting validation).

- Correcting the initialization of init_D in ParamsPalmFFT.
- Overriding Palm4MSA::next_step(),
- Defining the compute_D() and compute_D_grad_over_c() (functions proper to Palm4MSAFFT).
- Completing the function Palm4MSAFFT::compute_lambda() relying on parent's compute_lambda().

- Updating the test test_palm4MSAFFT.cpp.in (the matrix to factorize must be square).

- Minor changes in Palm4MSA: just a comment and indenting.
parent a58590e6
Branches
Tags
No related merge requests found
......@@ -60,13 +60,21 @@ typedef @TEST_FPP2@ FPP2;
*/
int main()
{
Faust::MatDense<FPP,Cpu> data, initFacts1, initFacts2;
Faust::MatDense<FPP,Cpu> data, orig_data;
//TODO: this test is just a sketch to verify WIP and is to re-define properly later
char configPalm2Filename[] = "@FAUST_DATA_MAT_DIR@/config_compared_palm2.mat";
init_faust_mat_from_matio(data, configPalm2Filename, "data");
init_faust_mat_from_matio(initFacts1, configPalm2Filename, "init_facts1");
init_faust_mat_from_matio(initFacts2, configPalm2Filename, "init_facts2");
init_faust_mat_from_matio(orig_data, configPalm2Filename, "data");
// we only used a 32x32 part of the data from .mat file because in Palm4MSAFFT
// we work on square matrix (FFT)
data = Faust::MatDense<FPP,Cpu>(orig_data.getData(), 32,32);
data.scalarMultiply((FPP).2);
Faust::MatDense<FPP, Cpu> initFacts1(data.getNbRow(), data.getNbCol());
Faust::MatDense<FPP, Cpu> initFacts2(data.getNbRow(), data.getNbCol());
initFacts1.setEyes();
initFacts2.setEyes();
int cons1Name, cons1Parameter, cons1Row, cons1Col;
int cons2Name, cons2Row, cons2Col;
......@@ -77,14 +85,13 @@ int main()
cons1Name = init_int_from_matio(configPalm2Filename, "cons1_name");
cons1Parameter = init_int_from_matio(configPalm2Filename, "cons1_parameter");
cons1Row = init_int_from_matio(configPalm2Filename, "cons1_row");
cons1Col = init_int_from_matio(configPalm2Filename, "cons1_col");
cons1Row = initFacts1.getNbRow();
cons1Col = initFacts1.getNbCol();
cons2Name = init_int_from_matio(configPalm2Filename, "cons2_name");
cons2Parameter = (FPP2) init_double_from_matio(configPalm2Filename, "cons2_parameter");
cons2Row = init_int_from_matio(configPalm2Filename, "cons2_row");
cons2Col = init_int_from_matio(configPalm2Filename, "cons2_col");
cons2Row = initFacts2.getNbRow();
cons2Col = initFacts2.getNbCol();
initLambda = (FPP) init_double_from_matio(configPalm2Filename, "init_lambda");
nfacts = init_int_from_matio(configPalm2Filename, "nfacts");
......@@ -95,7 +102,8 @@ int main()
// Creation du vecteur de contrainte
const Faust::ConstraintInt<FPP,Cpu> cons1(static_cast<faust_constraint_name>(cons1Name), cons1Parameter, cons1Row, cons1Col);
const Faust::ConstraintFPP<FPP,Cpu,FPP2> cons2(static_cast<faust_constraint_name>(cons2Name), cons2Parameter, cons2Row, cons2Col);
const Faust::ConstraintFPP<FPP,Cpu,FPP2> cons2(static_cast<faust_constraint_name>(cons2Name), cons2Parameter, cons2Row, cons2Col);
vector<const Faust::ConstraintGeneric*> cons;
cons.push_back(&cons1);
......@@ -117,8 +125,6 @@ int main()
Faust::Palm4MSAFFT<FPP,Cpu,FPP2> palm2(params,blasHandle,true);
// Faust::MatDense<FPP,Cpu> eye = Faust::MatDense<FPP,Cpu>::eye(10,10);
// eye.Display();
// palm2.next_step();
......
......@@ -122,7 +122,7 @@ namespace Faust
* whereas all the other are set to identity
*/
void init_fact(int nb_facts_);
void next_step();
virtual void next_step();
bool do_continue(){bool cont=stop_crit.do_continue(++m_indIte); if(!cont){m_indIte=-1;isConstraintSet=false;}return cont;} // CAUTION !!! pre-increment of m_indIte: the value in stop_crit.do_continue is m_indIte+1, not m_indIte
//bool do_continue()const{return stop_crit.do_continue(++m_indIte, error);};
......
......@@ -375,8 +375,8 @@ t_local_compute_grad_over_c.stop();
template<typename FPP, Device DEVICE, typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_xt_xhat(Faust::MatDense<FPP,DEVICE>& Xt_Xhat)
{
//TODO: replace by two functions to point to to avoid the comprison at each iteration
//TODO: replate compute_xt_xhat by a function pointer
//TODO: replace by two functions to point to, to avoid the comparison at each iteration
//TODO: replace compute_xt_xhat by a function pointer
if(typeid(FPP) == typeid(complex<double>) || typeid(FPP) == typeid(complex<float>)){
MatDense<FPP,DEVICE> data_cpy = data;
data_cpy.conjugate(false);
......@@ -390,8 +390,8 @@ void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_xt_xhat(Faust::MatDense<FPP,DEVIC
template<typename FPP, Device DEVICE, typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_xhatt_xhat(Faust::MatDense<FPP,DEVICE>& Xhatt_Xhat) {
//TODO: replace by two functions to point to to avoid the comprison at each iteration
//TODO: replate compute_xhatt_xhat by a function pointer
//TODO: replace by two functions to point to, to avoid the comparison at each iteration
//TODO: replace compute_xhatt_xhat by a function pointer
if(typeid(FPP) == typeid(complex<double>) || typeid(FPP) == typeid(complex<float>)){
Faust::MatDense<FPP,DEVICE> tmp_LoR = LorR;
tmp_LoR.conjugate(false);
......@@ -675,10 +675,10 @@ t_local_next_step.start();
int* ind_ptr = new int[m_nbFact];
for (int j=0 ; j<m_nbFact ; j++)
if (!isUpdateWayR2L)
ind_ptr[j] = j;
else
ind_ptr[j] = m_nbFact-1-j;
if (!isUpdateWayR2L)
ind_ptr[j] = j;
else
ind_ptr[j] = m_nbFact-1-j;
for (int j=0 ; j<m_nbFact ; j++)
{
......
......@@ -12,13 +12,16 @@ namespace Faust {
// Palm4MSAFFT: PALM4MSA variant that additionally optimizes a diagonal factor D,
// so the approximated model is Xhat = (S[0]*...*S[nfact-1]) * D * (S[0]*...*S[nfact-1])'
// (FFT / eigen-decomposition use case — see compute_lambda() in the .hpp).
class Palm4MSAFFT : public Palm4MSA<FPP, DEVICE, FPP2>
{
// The (square) diagonal factor, updated each iteration by compute_D().
MatDense<FPP, DEVICE> D; //TODO: later it will need to be Sparse (which needs to add a prototype overload for multiplication in faust_linear_algebra.h)
// Gradient of the objective w.r.t. D divided by the Lipschitz constant c.
// Also used by compute_lambda() as scratch storage for LorR*D*LorR'.
Faust::MatDense<FPP,DEVICE> D_grad_over_c; //TODO: move to sparse mat later
public:
//TODO: another ctor (like in Palm4MSA) for hierarchical algo. use
// params must provide init_D (the initial diagonal matrix); isGlobal is
// forwarded to the Palm4MSA parent ctor.
Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>& params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal=false);
// One global iteration: parent's factor updates, then the D update.
virtual void next_step();
private:
virtual void compute_grad_over_c();
// Overrides the parent to account for the D factor in the lambda computation.
virtual void compute_lambda();
// Gradient step + projection onto diagonal matrices for D.
void compute_D();
// Computes D_grad_over_c; relies on state left by compute_lambda().
void compute_D_grad_over_c();
};
#include "faust_Palm4MSAFFT.hpp"
......
#include <cstring>
#include <vector>
// Constructs the algorithm state. All common setup is delegated to the
// Palm4MSA parent ctor; the only FFT-specific work is copy-initializing the
// diagonal factor D from params.init_D.
// @param params     algorithm parameters (provides init_D among others)
// @param blasHandle BLAS handle used by the gemm calls of the iterations
// @param isGlobal   forwarded to Palm4MSA (hierarchical/global-algorithm use)
template <typename FPP, Device DEVICE, typename FPP2>
Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>& params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal) : Palm4MSA<FPP,DEVICE,FPP2>(params, blasHandle, isGlobal), D(params.init_D)
{
//TODO: manage init_D ?
//TODO: is there something to check additionally to what parent's ctor checks ?
}
......@@ -157,5 +157,63 @@ template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda()
{
// Overrides the parent for the FFT model:
//   Xhat = (S[0]*...*S[nfact-1])*D*(S[0]*...*S[nfact-1])' = LorR*D*LorR'
// where LorR is the product of all factors after their update iterations
// (in the loop of next_step()).
MatDense<FPP,Cpu> tmp;
// tmp = D*LorR'
// NOTE(review): 'T' is a plain transpose; for complex FPP the model's
// adjoint would presumably need 'H' — confirm against parent's handling.
gemm(this->D, this->LorR, tmp, (FPP) 1.0, (FPP) 0.0, 'N', 'T', this->blas_handle);
// D_grad_over_c = LorR*tmp = LorR*D*LorR' (used here as scratch storage)
gemm(this->LorR, tmp, D_grad_over_c, (FPP) 1.0, (FPP) 0.0, 'N', 'N', this->blas_handle);
// Temporarily swap LorR for the full model product LorR*D*LorR' so the
// parent's compute_lambda() operates on it instead of the factor product.
tmp = this->LorR;
this->LorR = D_grad_over_c;
//NOTE: D_grad_over_c has nothing to do here conceptually, but it now equals
// LorR*D*LorR'; compute_D_grad_over_c() relies on this and does not
// recompute that product.
//TODO: avoid all these copies
// at this stage we can rely on parent function to compute lambda
Palm4MSA<FPP,DEVICE,FPP2>::compute_lambda();
// restore LorR to the factor product to continue the next iterations
this->LorR = tmp;
// then finish the lambda computation with a sqrt() (Fro. norm) —
// an operation specific to Palm4MSAFFT.
this->m_lambda = std::sqrt(this->m_lambda);
}
// Performs one global iteration of the algorithm: runs the parent's factor
// update loop, then updates the diagonal factor D (specific to the FFT variant).
template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::next_step()
{
// Use DEVICE rather than a hard-coded Cpu so that non-CPU instantiations of
// this template dispatch to the correct parent specialization.
Palm4MSA<FPP, DEVICE, FPP2>::next_step();
// besides what the parent has done, we need to update D
this->compute_D();
}
// Gradient-descent update of the diagonal factor:
//   D <- proj_diag( D - (lambda/c) * grad_D )
// where proj_diag zeroes every off-diagonal entry (D must stay diagonal).
template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D()
{
// D_grad_over_c <- grad_D / c (relies on state left by compute_lambda()).
compute_D_grad_over_c();
D_grad_over_c.scalarMultiply(this->m_lambda/this->c);
D -= D_grad_over_c;
//TODO: optimize MatSparse + no-copy (Eigen::DiagonalMatrix ?)
// Projection onto diagonal matrices: rebuild D keeping only its diagonal.
// std::vector replaces the raw new[] of the previous version, which was never
// delete[]'d (MatDense copies from the buffer — see its use on an external
// const buffer in the test — so the allocation leaked every iteration).
// D is square here (FFT), hence the single getNbCol() bound is safe.
const faust_unsigned_int nrows = D.getNbRow(), ncols = D.getNbCol();
std::vector<FPP> diag_buf(nrows*ncols, FPP(0));
for(faust_unsigned_int i = 0; i < ncols; i++)
diag_buf[i*ncols+i] = D[i*ncols+i];
// NOTE(review): DEVICE replaces the previous hard-coded Cpu; assumes the
// (ptr, rows, cols) ctor exists for every MatDense specialization — confirm.
D = MatDense<FPP,DEVICE>(diag_buf.data(), nrows, ncols);
}
// Computes the gradient of the objective w.r.t. D (divided by c) into
// D_grad_over_c. PRECONDITION: compute_lambda() must have run just before —
// it leaves D_grad_over_c equal to LorR*D*LorR', which this function reuses
// instead of recomputing the product.
template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D_grad_over_c()
{
// grad = 0.5*LorR'*(LorR*D*LorR' - X)*LorR
// NOTE(review): the code below applies no 0.5 factor, so it actually computes
// LorR'*(LorR*D*LorR' - X)*LorR — either this comment's formula or the gemm
// alpha should be reconciled; confirm against the reference derivation.
MatDense<FPP, Cpu> tmp;
//compute_lambda has already computed D_grad_over_c = LorR*D*LorR'
D_grad_over_c -= this->data;
//TODO: opt. by determining best order of product
// tmp = LorR'*(LorR*D*LorR' - X)
// NOTE(review): 'T' is a plain transpose; complex FPP would presumably need
// the conjugate-transpose 'H' — confirm.
gemm(this->LorR, D_grad_over_c, tmp, (FPP) 1., (FPP) 0., 'T', 'N', this->blas_handle);
// D_grad_over_c = LorR'*(LorR*D*LorR' - X)*LorR
gemm(tmp, this->LorR, D_grad_over_c, (FPP) 1., (FPP) 0., 'N', 'N', this->blas_handle);
}
......@@ -24,7 +24,7 @@ namespace Faust
const bool isUpdateWayR2L_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultUpdateWayR2L ,
const FPP init_lambda_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultLambda,
const bool constant_step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultConstantStepSize,
const FPP step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, constant_step_size_, step_size_), init_D(MatDense<FPP,DEVICE>::eye(data_.getNbRow(), data_.getNbCol())) {}
const FPP step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, constant_step_size_, step_size_), init_D(MatDense<FPP,DEVICE>::eye(data_.getNbCol(), data_.getNbCol())) {}
ParamsPalmFFT() : ParamsPalm<FPP,DEVICE,FPP2>(), init_D(0,0) {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment