Mentions légales du service

Skip to content
Snippets Groups Projects
Commit a58590e6 authored by hhakim's avatar hhakim
Browse files

Implement Palm4MSAFFT::compute_grad_over_c().

- It needs to be checked again but it should work.
- It's based on CodeLuc/utils/grad_FFTgraph_S.m.
- The optimization is not complete (it just follows the opt. from Palm4MSA parent).
- Attribute D (matrix to diagonalize) added in Palm4MSAFFT.
- Other minor changes in faust_linear_algebra (comment complement).
parent 372729a1
No related branches found
No related tags found
No related merge requests found
......@@ -112,7 +112,7 @@ namespace Faust
FPP get_RMSE()const{return Faust::fabs(error.norm())/sqrt((double)(data.getNbRow()*data.getNbCol()));}
const Faust::MatDense<FPP,DEVICE>& get_res(bool isFactSideLeft_, int ind_)const{return isFactSideLeft_ ? S[0] : S[ind_+1];}
const Faust::MatDense<FPP,DEVICE>& get_data()const{return data;}
void get_facts(Faust::Transform<FPP,DEVICE> & faust_fact) const;
void get_facts(Faust::Transform<FPP,DEVICE> & faust_fact) const;
/*!
......@@ -151,7 +151,7 @@ namespace Faust
Faust::StoppingCriterion<FPP2> stop_crit;
private:
protected:
// modif AL AL
Faust::MatDense<FPP,DEVICE> data;
......
......@@ -11,12 +11,14 @@ namespace Faust {
template<typename FPP, Device DEVICE, typename FPP2 = double>
class Palm4MSAFFT : public Palm4MSA<FPP, DEVICE, FPP2>
{
MatDense<FPP, DEVICE> D; //TODO: later it will need to be Sparse (which needs to add a prototype overload for multiplication in faust_linear_algebra.h)
public:
//TODO: another ctor (like in Palm4MSA) for hierarchical algo. use
Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>& params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal=false);
private:
virtual void compute_grad_over_c();
virtual void compute_lambda();
};
#include "faust_Palm4MSAFFT.hpp"
......
template <typename FPP, Device DEVICE, typename FPP2>
Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>& params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal) : Palm4MSA<FPP,DEVICE,FPP2>(params, blasHandle, isGlobal)
Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>& params, const BlasHandle<DEVICE> blasHandle, const bool isGlobal) : Palm4MSA<FPP,DEVICE,FPP2>(params, blasHandle, isGlobal), D(params.init_D)
{
//TODO: manage init_D ?
......@@ -10,8 +10,146 @@ Palm4MSAFFT<FPP,DEVICE,FPP2>::Palm4MSAFFT(const ParamsPalmFFT<FPP, DEVICE, FPP2>
template <typename FPP, Device DEVICE, typename FPP2>
void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
{
//TODO: override parent's method
Palm4MSA<FPP,DEVICE,FPP2>::compute_grad_over_c();
if(!this->isCComputed)
{
throw std::logic_error("this->c must be set before computing grad/this->c");
}
/*! \brief There are 4 ways to compute gradient : <br>
* (0) : lambda*(L'*(lambda*(L*this->S)*R - X))*R' : complexity = L1*L2*S2 + L1*S2*R2 + L2*L1*R2 + L2*R2*R2; <br>
* (1) : lambda*L'*((lambda*(L*this->S)*R - X)*R') : complexity = L1*L2*S2 + L1*S2*R2 + L1*R2*S2 + L2*L1*S2; <br>
* (2) : lambda*(L'*(lambda*L*(this->S*R) - X))*R' : complexity = L2*S2*R2 + L1*L2*R2 + L2*L1*R2 + L2*R2*S2; <br>
* (3) : lambda*L'*((lambda*L*(this->S*R) - X)*R') : complexity = L2*S2*R2 + L1*L2*R2 + L1*R2*S2 + L2*L1*S2; <br>
* with L of size L1xL2 <br>
* this->S of size L2xS2 <br>
* R of size S2xR2 <br>
*/
unsigned long long int L1, L2, R2, S2;
if (!this->isUpdateWayR2L)
{
L1 = (unsigned long long int) this->LorR.getNbRow();
L2 = (unsigned long long int) this->LorR.getNbCol();
R2 = (unsigned long long int) this->RorL[this->m_indFact].getNbCol();
}
else
{
L1 = (unsigned long long int) this->RorL[this->m_indFact].getNbRow();
L2 = (unsigned long long int) this->RorL[this->m_indFact].getNbCol();
R2 = (unsigned long long int) this->LorR.getNbCol();
}
S2 = (unsigned long long int) this->S[this->m_indFact].getNbCol();
vector<unsigned long long int > complexity(4,0);
complexity[0] = L1*L2*S2 + L1*S2*R2 + L2*L1*R2 + L2*R2*R2;
complexity[1] = L1*L2*S2 + L1*S2*R2 + L1*R2*S2 + L2*L1*S2;
complexity[2] = L2*S2*R2 + L1*L2*R2 + L2*L1*R2 + L2*R2*S2;
complexity[3] = L2*S2*R2 + L1*L2*R2 + L1*R2*S2 + L2*L1*S2;
int idx = distance(complexity.begin(), min_element(complexity.begin(), complexity.end()));
this->error = this->data;
Faust::MatDense<FPP,DEVICE> tmp1,tmp2,tmp3;
if (idx==0 || idx==1) // computing L*this->S first, then (L*this->S)*R and finally the this->error
{
if (!this->isUpdateWayR2L)
{
// tmp1 = L*this->S
multiply(this->LorR, this->S[this->m_indFact], tmp1, this->blas_handle);
// tmp2 = L*this->S*R
multiply(tmp1, this->RorL[this->m_indFact], tmp2, this->blas_handle);
}
else
{
// tmp1 = L*this->S
multiply(this->RorL[this->m_indFact], this->S[this->m_indFact], tmp1, this->blas_handle);
// tmp2 = L*this->S*R
multiply(tmp1, this->LorR, tmp2, this->blas_handle);
}
}
else // computing this->S*R first, then L*(this->S*R)
{
if (!this->isUpdateWayR2L)
{
// tmp1 = this->S*R
multiply(this->S[this->m_indFact], this->RorL[this->m_indFact], tmp1, this->blas_handle);
// tmp2 = L*this->S*R
multiply(this->LorR, tmp1, tmp2, this->blas_handle);
}
else
{
// tmp1 = this->S*R
multiply(this->S[this->m_indFact], this->LorR, tmp1, this->blas_handle);
// tmp2 = L*this->S*R
multiply(this->RorL[this->m_indFact], tmp1, tmp2, this->blas_handle);
}
}
// tmp1 = L*this->S*R*D //TODO: review the mul with D being MatSparse
//TODO: optimize by determining the best product order regarding computation time
multiply(tmp2, D, tmp1, this->blas_handle);
// this->error = lambda*tmp1*lambda*tmp2'-data // this->error is data before this call
gemm(tmp1, tmp2, this->error, this->m_lambda*this->m_lambda, (FPP)-1.0, 'N', 'T', this->blas_handle);
//false is for disabling evaluation (because the transpose does it later)
this->LorR.conjugate(false);
this->RorL[this->m_indFact].conjugate(false);
if (idx==0 || idx==2) // computing L'*this->error first, then (L'*this->error)*R'
{
if (!this->isUpdateWayR2L)
{
// tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) )
gemm(this->LorR, this->error, tmp3, this->m_lambda,(FPP) 0.0, 'T', 'N', this->blas_handle);
// tmp2 = lambda*L*this->S*R*D*R'
gemm(tmp1, this->RorL[this->m_indFact], tmp2, this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
}
else
{
// tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) )
gemm(this->RorL[this->m_indFact], this->error, tmp3, this->m_lambda, (FPP) 0.0, 'T', 'N', this->blas_handle);
// tmp2 = lambda*L*this->S*R*D*R'
gemm(tmp1, this->LorR, tmp2, this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
}
// grad_over_c = 1/this->c*tmp3*tmp2
gemm(tmp3, tmp2, this->grad_over_c, (FPP) 1.0/this->c, (FPP) (FPP) 0.0,'N','N', this->blas_handle);
}
else // computing this->error*R' first, then L'*(this->error*lambda*LSRD*R')
{
if (!this->isUpdateWayR2L)
{
// tmp2 = lambda*tmp1*R' = lambda*LSRD*R'
gemm(tmp1, this->RorL[this->m_indFact], tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
// tmp3 = this->m_lambda*this->error*tmp2
gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle);
// grad_over_c = 1/this->c*L'*tmp3
gemm(this->LorR, tmp3, this->grad_over_c,(FPP) 1.0/this->c, (FPP) 0.0,'T','N', this->blas_handle);
}
else
{
// tmp2 = lambda*tmp1*R' = lambda*LSRD*R'
gemm(tmp1, this->LorR, tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
// tmp3 = this->m_lambda*this->error*tmp2
gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle);
// grad_over_c = 1/this->c*L'*tmp3
gemm(this->RorL[this->m_indFact], tmp3, this->grad_over_c, (FPP) 1.0/this->c, (FPP) 0.0,'T','N', this->blas_handle);
}
}
//TODO: avoid type checking by adding another template function (for complex and real types) or function pointer
if(typeid(this->data.getData()[0]) == typeid(complex<float>) || typeid(this->data.getData()[0]) == typeid(complex<double>))
{
this->LorR.conjugate();
this->RorL[this->m_indFact].conjugate();
}
this->isGradComputed = true;
#ifdef __COMPILE_TIMERS__
t_global_compute_grad_over_c.stop();
t_local_compute_grad_over_c.stop();
#endif
}
......
......@@ -89,7 +89,7 @@ namespace Faust
//////////FONCTION Faust::MatDense<FPP,Cpu> - Faust::MatDense<FPP,Cpu> ////////////////////
// C = A * B;
//l'objet C doit etre different de A et B
//l'objet C doit etre different de A et B (but gemm() manages it with two copies)
//! \fn multiply
//! \brief Multiplication C = A * B
//! \warning Object C must be different of A and B.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment