Mentions légales du service

Skip to content
Snippets Groups Projects
Commit b10edd22 authored by hhakim's avatar hhakim
Browse files

Handle Hermitian matrix product in gemm_core() function (faust_linear_algebra)...

Handle Hermitian matrix product in gemm_core() function (faust_linear_algebra) and use it in Palm4MSA(FFT) algos.

- Before: 'T' denoted the transposed matrix and 'N' the matrix as is. Now: an additional 'H' denotes the conjugate-transpose (Hermitian transpose) matrix.
- 'H' is not yet handled by the BLAS gemm computation: it is only supported with Eigen, and using 'H' with BLAS raises an exception.
- This allows Palm4MSA to handle the approximation of complex matrices more properly.
- Palm4MSA is modified, optimized and clarified by using this new capability.
parent 3fcf374f
No related branches found
No related tags found
No related merge requests found
......@@ -132,8 +132,6 @@ namespace Faust
*/
void init_fact_from_palm(const Palm4MSA& palm, bool isFactSideLeft);
const std::vector<Faust::MatDense<FPP,DEVICE> >& get_facts()const {return S;}
void compute_xt_xhat(MatDense<FPP,DEVICE>& Xt_Xhat);
void compute_xhatt_xhat(MatDense<FPP,DEVICE>& Xt_Xhat);
~Palm4MSA(){}
protected:
......@@ -189,8 +187,10 @@ namespace Faust
FPP c;
Faust::MatDense<FPP,DEVICE> error; // error = lambda*L*S*R - data
Faust::BlasHandle<DEVICE> blas_handle;
/** is_complex == true if the algorithm is running on a complex matrix (to approximate) */
bool is_complex;
/** TorH == 'H' if this->is_complex == true, otherwise it's 'T'. 'T' designates the transpose and 'H' the Hermitian (conjugate) transpose. It is used by the Palm4MSA algorithms when computing the gradient and lambda, so that the Hermitian transpose is applied when the matrix to approximate is complex. */
const char TorH;
......
......@@ -99,7 +99,10 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::MatDense<FPP,DEVICE>& M,
isGlobal(isGlobal_),
isInit(false),
c(FPP(1)/params_.step_size),
blas_handle(blasHandle)
blas_handle(blasHandle),
is_complex(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>)
),
TorH(is_complex?'H':'T')
{
RorL.reserve(params_.m_nbFact);
......@@ -136,7 +139,10 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::ParamsPalm<FPP,DEVICE,FP
isConstraintSet(false),
isGlobal(isGlobal_),
c(FPP(1)/params_palm_.step_size),
blas_handle(blasHandle)
blas_handle(blasHandle),
is_complex(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>)
),
TorH(is_complex?'H':'T')
{
RorL.reserve(const_vec.size()+1);
......@@ -327,16 +333,16 @@ sprintf(nomFichier,"error_1_%d_device.tmp",cmpt);*/
if (!isUpdateWayR2L)
{
// tmp3 = m_lambda*L'*error (= m_lambda*L' * (m_lambda*L*S*R - data) )
gemm(LorR, error, tmp3, m_lambda,(FPP) 0.0, 'T', 'N', blas_handle);
gemm(LorR, error, tmp3, m_lambda,(FPP) 0.0, TorH, 'N', blas_handle);
// grad_over_c = 1/c*tmp3*R' (= 1/c*m_lambda*L' * (m_lambda*L*S*R - data) * R' )
gemm(tmp3, RorL[m_indFact], grad_over_c,(FPP) 1.0/c,(FPP) 0.0,'N','T', blas_handle);
gemm(tmp3, RorL[m_indFact], grad_over_c,(FPP) 1.0/c,(FPP) 0.0,'N',TorH, blas_handle);
}
else
{
// tmp3 = m_lambda*L'*error (= m_lambda*L' * (m_lambda*L*S*R - data) )
gemm(RorL[m_indFact], error, tmp3, m_lambda, (FPP) 0.0, 'T', 'N', blas_handle);
gemm(RorL[m_indFact], error, tmp3, m_lambda, (FPP) 0.0, TorH, 'N', blas_handle);
// grad_over_c = 1/c*tmp3*R' (= 1/c*m_lambda*L' * (m_lambda*L*S*R - data) * R' )
gemm(tmp3, LorR, grad_over_c, (FPP) 1.0/c, (FPP) (FPP) 0.0,'N','T', blas_handle);
gemm(tmp3, LorR, grad_over_c, (FPP) 1.0/c, (FPP) (FPP) 0.0,'N',TorH, blas_handle);
}
}
else // computing error*R' first, then L'*(error*R')
......@@ -344,26 +350,20 @@ sprintf(nomFichier,"error_1_%d_device.tmp",cmpt);*/
if (!isUpdateWayR2L)
{
// tmp3 = m_lambda*error*R' (= m_lambda*(m_lambda*L*S*R - data) * R' )
gemm(error, RorL[m_indFact], tmp3, m_lambda, (FPP) 0.0, 'N', 'T', blas_handle);
gemm(error, RorL[m_indFact], tmp3, m_lambda, (FPP) 0.0, 'N', TorH, blas_handle);
// grad_over_c = 1/c*L'*tmp3 (= 1/c*L' * m_lambda*(m_lambda*L*S*R - data) * R' )
gemm(LorR, tmp3, grad_over_c,(FPP) 1.0/c, (FPP) 0.0,'T','N', blas_handle);
gemm(LorR, tmp3, grad_over_c,(FPP) 1.0/c, (FPP) 0.0,TorH,'N', blas_handle);
}
else
{
// tmp3 = m_lambda*error*R' (= m_lambda * (m_lambda*L*S*R - data) * R' )
gemm(error, LorR, tmp3, m_lambda, (FPP) 0.0, 'N', 'T', blas_handle);
gemm(error, LorR, tmp3, m_lambda, (FPP) 0.0, 'N', TorH, blas_handle);
// grad_over_c = 1/c*L'*tmp3 (= 1/c*L' * m_lambda*(m_lambda*L*S*R - data) * R' )
gemm(RorL[m_indFact], tmp3, grad_over_c, (FPP) 1.0/c, (FPP) 0.0,'T','N', blas_handle);
gemm(RorL[m_indFact], tmp3, grad_over_c, (FPP) 1.0/c, (FPP) 0.0,TorH,'N', blas_handle);
}
}
//TODO: avoid type checking by adding another template function (for complex and real types) or function pointer
if(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>))
{
LorR.conjugate();
RorL[m_indFact].conjugate();
}
isGradComputed = true;
#ifdef __COMPILE_TIMERS__
......@@ -372,39 +372,6 @@ t_local_compute_grad_over_c.stop();
#endif
}
/**
 * \brief Fills Xt_Xhat with the gemm product of data by LorR, the first operand being flagged 'T' and the second 'N' (alpha = 1, beta = 0).
 *
 * When FPP is complex<double> or complex<float>, the first operand is a copy of data on which conjugate(false) has been called, and real() is then called on the result.
 *
 * \param Xt_Xhat output matrix receiving the product.
 */
template<typename FPP, Device DEVICE, typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_xt_xhat(Faust::MatDense<FPP,DEVICE>& Xt_Xhat)
{
	//TODO: dispatch to two dedicated functions (or a function pointer) so that the type comparison is not repeated at each iteration
	const bool fpp_is_complex = typeid(FPP) == typeid(complex<double>) || typeid(FPP) == typeid(complex<float>);
	if(! fpp_is_complex)
	{
		// real scalar type: data is used directly
		gemm(data, LorR, Xt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N', blas_handle);
		return;
	}
	// complex scalar type: work on a conjugated copy so that data itself is left untouched
	MatDense<FPP,DEVICE> conj_data = data;
	conj_data.conjugate(false);
	gemm(conj_data, LorR, Xt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N', blas_handle);
	// NOTE(review): presumably keeps only the real part of the product -- confirm against MatDense::real()
	Xt_Xhat.real();
}
/**
 * \brief Fills Xhatt_Xhat with the gemm product of LorR by itself, the first operand being flagged 'T' and the second 'N' (alpha = 1, beta = 0).
 *
 * When FPP is complex<double> or complex<float>, the first operand is a copy of LorR on which conjugate(false) has been called, and real() is then called on the result.
 *
 * \param Xhatt_Xhat output matrix receiving the product.
 */
template<typename FPP, Device DEVICE, typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_xhatt_xhat(Faust::MatDense<FPP,DEVICE>& Xhatt_Xhat)
{
	//TODO: dispatch to two dedicated functions (or a function pointer) so that the type comparison is not repeated at each iteration
	const bool fpp_is_complex = typeid(FPP) == typeid(complex<double>) || typeid(FPP) == typeid(complex<float>);
	if(! fpp_is_complex)
	{
		// real scalar type: LorR is used directly on both sides
		gemm(LorR, LorR, Xhatt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N',blas_handle);
		return;
	}
	// complex scalar type: the left operand is a conjugated copy, LorR itself is left untouched
	Faust::MatDense<FPP,DEVICE> conj_LorR = LorR;
	conj_LorR.conjugate(false);
	gemm(conj_LorR, LorR, Xhatt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N',blas_handle);
	// NOTE(review): presumably keeps only the real part of the product -- confirm against MatDense::real()
	Xhatt_Xhat.real();
}
template<typename FPP,Device DEVICE,typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_lambda()
{
......@@ -421,10 +388,10 @@ t_local_compute_lambda.start();
// As LorR has also been updated at the end of the last iteration over the facts, LorR matches X_hat, which is the product of all factors, including the last one.
// Xt_Xhat = data'*X_hat
Faust::MatDense<FPP,DEVICE> Xt_Xhat;
this->compute_xt_xhat(Xt_Xhat);
gemm(data, LorR, Xt_Xhat, (FPP) 1.0, (FPP) 0.0, TorH,'N', blas_handle);
// Xhatt_Xhat = X_hat'*X_hat
Faust::MatDense<FPP,DEVICE> Xhatt_Xhat;
this->compute_xhatt_xhat(Xhatt_Xhat);
gemm(LorR, LorR, Xhatt_Xhat, (FPP) 1.0, (FPP) 0.0, TorH,'N',blas_handle);
FPP Xhatt_Xhat_tr = (FPP) Xhatt_Xhat.trace();
......
......@@ -93,27 +93,24 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
//TODO: optimize by determining the best product order regarding computation time
multiply(tmp2, D, tmp1, this->blas_handle);
// this->error = lambda*tmp1*lambda*tmp2'-data // this->error is data before this call
gemm(tmp1, tmp2, this->error, this->m_lambda*this->m_lambda, (FPP)-1.0, 'N', 'T', this->blas_handle);
gemm(tmp1, tmp2, this->error, this->m_lambda*this->m_lambda, (FPP)-1.0, 'N', this->TorH, this->blas_handle);
//false is for disabling evaluation (because the transpose does it later)
this->LorR.conjugate(false);
this->RorL[this->m_indFact].conjugate(false);
if (idx==0 || idx==2) // computing L'*this->error first, then (L'*this->error)*R'
{
if (!this->isUpdateWayR2L)
{
// tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) )
gemm(this->LorR, this->error, tmp3, this->m_lambda,(FPP) 0.0, 'T', 'N', this->blas_handle);
gemm(this->LorR, this->error, tmp3, this->m_lambda,(FPP) 0.0, this->TorH, 'N', this->blas_handle);
// tmp2 = lambda*L*this->S*R*D*R'
gemm(tmp1, this->RorL[this->m_indFact], tmp2, this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
gemm(tmp1, this->RorL[this->m_indFact], tmp2, this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
}
else
{
// tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) )
gemm(this->RorL[this->m_indFact], this->error, tmp3, this->m_lambda, (FPP) 0.0, 'T', 'N', this->blas_handle);
gemm(this->RorL[this->m_indFact], this->error, tmp3, this->m_lambda, (FPP) 0.0, this->TorH, 'N', this->blas_handle);
// tmp2 = lambda*L*this->S*R*D*R'
gemm(tmp1, this->LorR, tmp2, this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
gemm(tmp1, this->LorR, tmp2, this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
}
// grad_over_c = 1/this->c*tmp3*tmp2
gemm(tmp3, tmp2, this->grad_over_c, (FPP) 1.0/this->c, (FPP) (FPP) 0.0,'N','N', this->blas_handle);
......@@ -124,30 +121,25 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
if (!this->isUpdateWayR2L)
{
// tmp2 = lambda*tmp1*R' = lambda*LSRD*R'
gemm(tmp1, this->RorL[this->m_indFact], tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
gemm(tmp1, this->RorL[this->m_indFact], tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
// tmp3 = this->m_lambda*this->error*tmp2
gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle);
// grad_over_c = 1/this->c*L'*tmp3
gemm(this->LorR, tmp3, this->grad_over_c,(FPP) 1.0/this->c, (FPP) 0.0,'T','N', this->blas_handle);
gemm(this->LorR, tmp3, this->grad_over_c,(FPP) 1.0/this->c, (FPP) 0.0,this->TorH,'N', this->blas_handle);
}
else
{
// tmp2 = lambda*tmp1*R' = lambda*LSRD*R'
gemm(tmp1, this->LorR, tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle);
gemm(tmp1, this->LorR, tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
// tmp3 = this->m_lambda*this->error*tmp2
gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle);
// grad_over_c = 1/this->c*L'*tmp3
gemm(this->RorL[this->m_indFact], tmp3, this->grad_over_c, (FPP) 1.0/this->c, (FPP) 0.0,'T','N', this->blas_handle);
gemm(this->RorL[this->m_indFact], tmp3, this->grad_over_c, (FPP) 1.0/this->c, (FPP) 0.0,this->TorH,'N', this->blas_handle);
}
}
//TODO: avoid type checking by adding another template function (for complex and real types) or function pointer
if(typeid(this->data.getData()[0]) == typeid(complex<float>) || typeid(this->data.getData()[0]) == typeid(complex<double>))
{
this->LorR.conjugate();
this->RorL[this->m_indFact].conjugate();
}
this->isGradComputed = true;
#ifdef __COMPILE_TIMERS__
......@@ -165,7 +157,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda()
// Xhat = LorR*D*LorR' // LorR equals the prod of all factors after their update iterations (in loop of next_step())
MatDense<FPP,Cpu> tmp;
// tmp = D*LorR'
gemm(this->D, this->LorR, tmp, (FPP) 1.0, (FPP) 0.0, 'N', 'T', this->blas_handle);
gemm(this->D, this->LorR, tmp, (FPP) 1.0, (FPP) 0.0, 'N', this->TorH, this->blas_handle);
// D_grad_over_c = LorR*tmp (= LorR*D*LorR')
gemm(this->LorR, tmp, D_grad_over_c, (FPP) 1.0, (FPP) 0.0, 'N', 'N', this->blas_handle);
tmp = this->LorR;
......@@ -216,7 +208,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D_grad_over_c()
D_grad_over_c -= this->data;
//TODO: opt. by determining best order of product
// tmp = LorR'*(LorR*D*LorR' - X)
gemm(this->LorR, D_grad_over_c, tmp, (FPP) 1., (FPP) 0., 'T', 'N', this->blas_handle);
gemm(this->LorR, D_grad_over_c, tmp, (FPP) 1., (FPP) 0., this->TorH, 'N', this->blas_handle);
// D_grad_over_c = LorR'*(LorR*D*LorR' - X)*LorR
gemm(tmp, this->LorR, D_grad_over_c, (FPP) 1., (FPP) 0., 'N', 'N', this->blas_handle);
}
......
......@@ -322,6 +322,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
#ifdef __COMPILE_TIMERS__
A.t_gemm.start();
#endif
#ifdef __GEMM_WITH_OPENBLAS__
if(typeA == 'H' || typeB == 'H')
handleError("linear_algebra", " gemm: Hermitian matrix is not yet handled with BLAS.")
#endif
faust_unsigned_int nbRowOpA,nbRowOpB,nbColOpA,nbColOpB;
if ( ((&(C.mat)) == (&(A.mat))) || ((&(C.mat)) == (&(B.mat))) )
......@@ -330,7 +335,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
handleError("linear_algebra", " gemm : C is the same object as A or B");
}
if (typeA == 'T')
if (typeA == 'T' || typeA == 'H')
{
nbRowOpA = A.getNbCol();
nbColOpA = A.getNbRow();
......@@ -341,7 +346,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
}
if (typeB == 'T')
if (typeB == 'T' || typeB == 'H')
{
nbRowOpB = B.getNbCol();
nbColOpB = B.getNbRow();
......@@ -404,6 +409,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
C=B;
if(typeB == 'T')
C.transpose();
else if(typeB == 'H')
{
C.conjugate(false);
C.transpose();
}
if(alpha!=FPP(1.0))
C*= alpha;
C.isZeros = false;
......@@ -418,6 +428,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
C=A;
if(typeA == 'T')
C.transpose();
else if(typeA == 'H')
{
C.conjugate(false);
C.transpose();
}
if(alpha!=FPP(1.0))
C*= alpha;
C.isZeros = false;
......@@ -435,14 +450,27 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
{
if (typeB == 'N')
C.mat.noalias() = alpha * A.mat * B.mat;
else
else if(typeB == 'T')
C.mat.noalias() = alpha * A.mat * B.mat.transpose();
}else
else // typeB == 'H' //TODO: check validity of typeA and typeB
C.mat.noalias() = alpha * A.mat * B.mat.transpose().conjugate();
}else if(typeA == 'T')
{
if (typeB == 'N')
C.mat.noalias() = alpha * A.mat.transpose() * B.mat;
else
else if(typeB == 'T')
C.mat.noalias() = alpha * A.mat.transpose() * B.mat.transpose();
else // typeB == 'H' //TODO: check validity of typeA and typeB
C.mat.noalias() = alpha * A.mat.transpose() * B.mat.transpose().conjugate();
}
else //if(typeA == 'H')
{
if (typeB == 'N')
C.mat.noalias() = alpha * A.mat.transpose().conjugate() * B.mat;
else if(typeB == 'T')
C.mat.noalias() = alpha * A.mat.transpose().conjugate() * B.mat.transpose();
else // typeB == 'H' //TODO: check validity of typeA and typeB
C.mat.noalias() = alpha * A.mat.transpose().conjugate() * B.mat.transpose().conjugate();
}
#else
FPP beta = FPP(0.0);
......@@ -478,6 +506,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
Faust::MatDense<FPP,Cpu> B_tmp(B);
if(typeB == 'T')
B_tmp.transpose();
else if(typeB == 'H')
{
B_tmp.conjugate(false);
B_tmp.transpose();
}
if(alpha != FPP(1.0))
B_tmp *= alpha;
C += B_tmp;
......@@ -504,6 +537,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
Faust::MatDense<FPP,Cpu> A_tmp(A);
if(typeA == 'T')
A_tmp.transpose();
else if(typeA == 'H')
{
A_tmp.conjugate(false);
A_tmp.transpose();
}
if(alpha != FPP(1.0))
A_tmp *= alpha;
C += A_tmp;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment