Mentions légales du service

Skip to content
Snippets Groups Projects
Commit b10edd22 authored by hhakim's avatar hhakim
Browse files

Handle Hermitian matrix product in gemm_core() function (faust_linear_algebra)...

Handle Hermitian matrix product in gemm_core() function (faust_linear_algebra) and use it in Palm4MSA(FFT) algos.

- Before: gemm_core() accepted 'T' for the transposed matrix and 'N' for the non-transposed (normal) matrix. Now: it additionally accepts 'H' for the conjugate transpose (Hermitian transpose) of the matrix.
- 'H' is not yet handled when gemm is computed with BLAS: it is only supported with Eigen, and an exception is raised if 'H' is used with BLAS.
- This allows Palm4MSA to handle the approximation of complex matrices more properly.
- Palm4MSA is modified, optimized and clarified by using this new capability.
parent 3fcf374f
Branches
Tags
No related merge requests found
...@@ -132,8 +132,6 @@ namespace Faust ...@@ -132,8 +132,6 @@ namespace Faust
*/ */
void init_fact_from_palm(const Palm4MSA& palm, bool isFactSideLeft); void init_fact_from_palm(const Palm4MSA& palm, bool isFactSideLeft);
const std::vector<Faust::MatDense<FPP,DEVICE> >& get_facts()const {return S;} const std::vector<Faust::MatDense<FPP,DEVICE> >& get_facts()const {return S;}
void compute_xt_xhat(MatDense<FPP,DEVICE>& Xt_Xhat);
void compute_xhatt_xhat(MatDense<FPP,DEVICE>& Xt_Xhat);
~Palm4MSA(){} ~Palm4MSA(){}
protected: protected:
...@@ -189,8 +187,10 @@ namespace Faust ...@@ -189,8 +187,10 @@ namespace Faust
FPP c; FPP c;
Faust::MatDense<FPP,DEVICE> error; // error = lambda*L*S*R - data Faust::MatDense<FPP,DEVICE> error; // error = lambda*L*S*R - data
Faust::BlasHandle<DEVICE> blas_handle; Faust::BlasHandle<DEVICE> blas_handle;
/** is_complex == true if the algorithm is running on a complex matrix (to approximate) */
bool is_complex;
/** TorH == 'T' if this->is_complex == false otherwise it's 'H'. T designates the transposition and H the hermitian matrix, it intervenes in Palm4MSA algorithms for the computation of the gradient and lambda so that the algo. uses the hermitian when working on complex matrices (i.e. the matrix to approx. is complex) */
const char TorH;
......
...@@ -99,7 +99,10 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::MatDense<FPP,DEVICE>& M, ...@@ -99,7 +99,10 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::MatDense<FPP,DEVICE>& M,
isGlobal(isGlobal_), isGlobal(isGlobal_),
isInit(false), isInit(false),
c(FPP(1)/params_.step_size), c(FPP(1)/params_.step_size),
blas_handle(blasHandle) blas_handle(blasHandle),
is_complex(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>)
),
TorH(is_complex?'H':'T')
{ {
RorL.reserve(params_.m_nbFact); RorL.reserve(params_.m_nbFact);
...@@ -136,7 +139,10 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::ParamsPalm<FPP,DEVICE,FP ...@@ -136,7 +139,10 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::ParamsPalm<FPP,DEVICE,FP
isConstraintSet(false), isConstraintSet(false),
isGlobal(isGlobal_), isGlobal(isGlobal_),
c(FPP(1)/params_palm_.step_size), c(FPP(1)/params_palm_.step_size),
blas_handle(blasHandle) blas_handle(blasHandle),
is_complex(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>)
),
TorH(is_complex?'H':'T')
{ {
RorL.reserve(const_vec.size()+1); RorL.reserve(const_vec.size()+1);
...@@ -327,16 +333,16 @@ sprintf(nomFichier,"error_1_%d_device.tmp",cmpt);*/ ...@@ -327,16 +333,16 @@ sprintf(nomFichier,"error_1_%d_device.tmp",cmpt);*/
if (!isUpdateWayR2L) if (!isUpdateWayR2L)
{ {
// tmp3 = m_lambda*L'*error (= m_lambda*L' * (m_lambda*L*S*R - data) ) // tmp3 = m_lambda*L'*error (= m_lambda*L' * (m_lambda*L*S*R - data) )
gemm(LorR, error, tmp3, m_lambda,(FPP) 0.0, 'T', 'N', blas_handle); gemm(LorR, error, tmp3, m_lambda,(FPP) 0.0, TorH, 'N', blas_handle);
// grad_over_c = 1/c*tmp3*R' (= 1/c*m_lambda*L' * (m_lambda*L*S*R - data) * R' ) // grad_over_c = 1/c*tmp3*R' (= 1/c*m_lambda*L' * (m_lambda*L*S*R - data) * R' )
gemm(tmp3, RorL[m_indFact], grad_over_c,(FPP) 1.0/c,(FPP) 0.0,'N','T', blas_handle); gemm(tmp3, RorL[m_indFact], grad_over_c,(FPP) 1.0/c,(FPP) 0.0,'N',TorH, blas_handle);
} }
else else
{ {
// tmp3 = m_lambda*L'*error (= m_lambda*L' * (m_lambda*L*S*R - data) ) // tmp3 = m_lambda*L'*error (= m_lambda*L' * (m_lambda*L*S*R - data) )
gemm(RorL[m_indFact], error, tmp3, m_lambda, (FPP) 0.0, 'T', 'N', blas_handle); gemm(RorL[m_indFact], error, tmp3, m_lambda, (FPP) 0.0, TorH, 'N', blas_handle);
// grad_over_c = 1/c*tmp3*R' (= 1/c*m_lambda*L' * (m_lambda*L*S*R - data) * R' ) // grad_over_c = 1/c*tmp3*R' (= 1/c*m_lambda*L' * (m_lambda*L*S*R - data) * R' )
gemm(tmp3, LorR, grad_over_c, (FPP) 1.0/c, (FPP) (FPP) 0.0,'N','T', blas_handle); gemm(tmp3, LorR, grad_over_c, (FPP) 1.0/c, (FPP) (FPP) 0.0,'N',TorH, blas_handle);
} }
} }
else // computing error*R' first, then L'*(error*R') else // computing error*R' first, then L'*(error*R')
...@@ -344,26 +350,20 @@ sprintf(nomFichier,"error_1_%d_device.tmp",cmpt);*/ ...@@ -344,26 +350,20 @@ sprintf(nomFichier,"error_1_%d_device.tmp",cmpt);*/
if (!isUpdateWayR2L) if (!isUpdateWayR2L)
{ {
// tmp3 = m_lambda*error*R' (= m_lambda*(m_lambda*L*S*R - data) * R' ) // tmp3 = m_lambda*error*R' (= m_lambda*(m_lambda*L*S*R - data) * R' )
gemm(error, RorL[m_indFact], tmp3, m_lambda, (FPP) 0.0, 'N', 'T', blas_handle); gemm(error, RorL[m_indFact], tmp3, m_lambda, (FPP) 0.0, 'N', TorH, blas_handle);
// grad_over_c = 1/c*L'*tmp3 (= 1/c*L' * m_lambda*(m_lambda*L*S*R - data) * R' ) // grad_over_c = 1/c*L'*tmp3 (= 1/c*L' * m_lambda*(m_lambda*L*S*R - data) * R' )
gemm(LorR, tmp3, grad_over_c,(FPP) 1.0/c, (FPP) 0.0,'T','N', blas_handle); gemm(LorR, tmp3, grad_over_c,(FPP) 1.0/c, (FPP) 0.0,TorH,'N', blas_handle);
} }
else else
{ {
// tmp3 = m_lambda*error*R' (= m_lambda * (m_lambda*L*S*R - data) * R' ) // tmp3 = m_lambda*error*R' (= m_lambda * (m_lambda*L*S*R - data) * R' )
gemm(error, LorR, tmp3, m_lambda, (FPP) 0.0, 'N', 'T', blas_handle); gemm(error, LorR, tmp3, m_lambda, (FPP) 0.0, 'N', TorH, blas_handle);
// grad_over_c = 1/c*L'*tmp3 (= 1/c*L' * m_lambda*(m_lambda*L*S*R - data) * R' ) // grad_over_c = 1/c*L'*tmp3 (= 1/c*L' * m_lambda*(m_lambda*L*S*R - data) * R' )
gemm(RorL[m_indFact], tmp3, grad_over_c, (FPP) 1.0/c, (FPP) 0.0,'T','N', blas_handle); gemm(RorL[m_indFact], tmp3, grad_over_c, (FPP) 1.0/c, (FPP) 0.0,TorH,'N', blas_handle);
} }
} }
//TODO: avoid type checking by adding another template function (for complex and real types) or function pointer
if(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>))
{
LorR.conjugate();
RorL[m_indFact].conjugate();
}
isGradComputed = true; isGradComputed = true;
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
...@@ -372,39 +372,6 @@ t_local_compute_grad_over_c.stop(); ...@@ -372,39 +372,6 @@ t_local_compute_grad_over_c.stop();
#endif #endif
} }
/**
 * \brief Fills Xt_Xhat with the product data' * LorR (the caller uses LorR as X_hat).
 *
 * Real scalar type: a single gemm with data transposed ('T') and LorR as is ('N').
 * Complex scalar type: a copy of data is conjugated first, so that the 'T' gemm
 * operates on the conjugate of data; Xt_Xhat.real() is then called on the result.
 *
 * \param Xt_Xhat output matrix, overwritten by the gemm (beta is 0).
 */
template<typename FPP, Device DEVICE, typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_xt_xhat(Faust::MatDense<FPP,DEVICE>& Xt_Xhat)
{
	//TODO: this scalar-type test runs at every call; replace it by two functions
	// (real/complex) selected once, e.g. through a function pointer
	const bool fpp_is_complex = typeid(FPP) == typeid(complex<double>) || typeid(FPP) == typeid(complex<float>);
	if(!fpp_is_complex)
	{
		// real case: Xt_Xhat = data^T * LorR
		gemm(data, LorR, Xt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N', blas_handle);
		return;
	}
	// complex case: work on a copy so that this->data is left untouched
	MatDense<FPP,DEVICE> conj_data = data;
	// NOTE(review): the false argument presumably defers the evaluation of the conjugate -- confirm against MatDense::conjugate
	conj_data.conjugate(false);
	gemm(conj_data, LorR, Xt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N', blas_handle);
	// NOTE(review): presumably keeps only the real part of the product -- confirm against MatDense::real
	Xt_Xhat.real();
}
/**
 * \brief Fills Xhatt_Xhat with the product LorR' * LorR (the caller uses LorR as X_hat).
 *
 * Real scalar type: a single gemm with the left operand transposed ('T').
 * Complex scalar type: a copy of LorR is conjugated first and used as the left
 * ('T') operand; Xhatt_Xhat.real() is then called on the result.
 *
 * \param Xhatt_Xhat output matrix, overwritten by the gemm (beta is 0).
 */
template<typename FPP, Device DEVICE, typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_xhatt_xhat(Faust::MatDense<FPP,DEVICE>& Xhatt_Xhat)
{
	//TODO: this scalar-type test runs at every call; replace it by two functions
	// (real/complex) selected once, e.g. through a function pointer
	const bool fpp_is_complex = typeid(FPP) == typeid(complex<double>) || typeid(FPP) == typeid(complex<float>);
	if(!fpp_is_complex)
	{
		// real case: Xhatt_Xhat = LorR^T * LorR
		gemm(LorR, LorR, Xhatt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N',blas_handle);
		return;
	}
	// complex case: conjugate a copy, this->LorR itself must stay unchanged
	Faust::MatDense<FPP,DEVICE> conj_LorR = LorR;
	// NOTE(review): the false argument presumably defers the evaluation of the conjugate -- confirm against MatDense::conjugate
	conj_LorR.conjugate(false);
	gemm(conj_LorR, LorR, Xhatt_Xhat, (FPP) 1.0, (FPP) 0.0, 'T','N',blas_handle);
	// NOTE(review): presumably keeps only the real part of the product -- confirm against MatDense::real
	Xhatt_Xhat.real();
}
template<typename FPP,Device DEVICE,typename FPP2> template<typename FPP,Device DEVICE,typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_lambda() void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_lambda()
{ {
...@@ -421,10 +388,10 @@ t_local_compute_lambda.start(); ...@@ -421,10 +388,10 @@ t_local_compute_lambda.start();
// As LorR has also been updated at the end of the last iteration over the facts, LorR matches X_hat, which the product of all factors, including the last one. // As LorR has also been updated at the end of the last iteration over the facts, LorR matches X_hat, which the product of all factors, including the last one.
// Xt_Xhat = data'*X_hat // Xt_Xhat = data'*X_hat
Faust::MatDense<FPP,DEVICE> Xt_Xhat; Faust::MatDense<FPP,DEVICE> Xt_Xhat;
this->compute_xt_xhat(Xt_Xhat); gemm(data, LorR, Xt_Xhat, (FPP) 1.0, (FPP) 0.0, TorH,'N', blas_handle);
// Xhatt_Xhat = X_hat'*X_hat // Xhatt_Xhat = X_hat'*X_hat
Faust::MatDense<FPP,DEVICE> Xhatt_Xhat; Faust::MatDense<FPP,DEVICE> Xhatt_Xhat;
this->compute_xhatt_xhat(Xhatt_Xhat); gemm(LorR, LorR, Xhatt_Xhat, (FPP) 1.0, (FPP) 0.0, TorH,'N',blas_handle);
FPP Xhatt_Xhat_tr = (FPP) Xhatt_Xhat.trace(); FPP Xhatt_Xhat_tr = (FPP) Xhatt_Xhat.trace();
......
...@@ -93,27 +93,24 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c() ...@@ -93,27 +93,24 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
//TODO: optimize by determining the best product order regarding computation time //TODO: optimize by determining the best product order regarding computation time
multiply(tmp2, D, tmp1, this->blas_handle); multiply(tmp2, D, tmp1, this->blas_handle);
// this->error = lambda*tmp1*lambda*tmp2'-data // this->error is data before this call // this->error = lambda*tmp1*lambda*tmp2'-data // this->error is data before this call
gemm(tmp1, tmp2, this->error, this->m_lambda*this->m_lambda, (FPP)-1.0, 'N', 'T', this->blas_handle); gemm(tmp1, tmp2, this->error, this->m_lambda*this->m_lambda, (FPP)-1.0, 'N', this->TorH, this->blas_handle);
//false is for disabling evaluation (because the transpose does it later)
this->LorR.conjugate(false);
this->RorL[this->m_indFact].conjugate(false);
if (idx==0 || idx==2) // computing L'*this->error first, then (L'*this->error)*R' if (idx==0 || idx==2) // computing L'*this->error first, then (L'*this->error)*R'
{ {
if (!this->isUpdateWayR2L) if (!this->isUpdateWayR2L)
{ {
// tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) ) // tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) )
gemm(this->LorR, this->error, tmp3, this->m_lambda,(FPP) 0.0, 'T', 'N', this->blas_handle); gemm(this->LorR, this->error, tmp3, this->m_lambda,(FPP) 0.0, this->TorH, 'N', this->blas_handle);
// tmp2 = lambda*L*this->S*R*D*R' // tmp2 = lambda*L*this->S*R*D*R'
gemm(tmp1, this->RorL[this->m_indFact], tmp2, this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle); gemm(tmp1, this->RorL[this->m_indFact], tmp2, this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
} }
else else
{ {
// tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) ) // tmp3 = this->m_lambda*L'*this->error (= this->m_lambda*L' * (this->m_lambda*L*this->S*R - data) )
gemm(this->RorL[this->m_indFact], this->error, tmp3, this->m_lambda, (FPP) 0.0, 'T', 'N', this->blas_handle); gemm(this->RorL[this->m_indFact], this->error, tmp3, this->m_lambda, (FPP) 0.0, this->TorH, 'N', this->blas_handle);
// tmp2 = lambda*L*this->S*R*D*R' // tmp2 = lambda*L*this->S*R*D*R'
gemm(tmp1, this->LorR, tmp2, this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle); gemm(tmp1, this->LorR, tmp2, this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
} }
// grad_over_c = 1/this->c*tmp3*tmp2 // grad_over_c = 1/this->c*tmp3*tmp2
gemm(tmp3, tmp2, this->grad_over_c, (FPP) 1.0/this->c, (FPP) (FPP) 0.0,'N','N', this->blas_handle); gemm(tmp3, tmp2, this->grad_over_c, (FPP) 1.0/this->c, (FPP) (FPP) 0.0,'N','N', this->blas_handle);
...@@ -124,30 +121,25 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c() ...@@ -124,30 +121,25 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_grad_over_c()
if (!this->isUpdateWayR2L) if (!this->isUpdateWayR2L)
{ {
// tmp2 = lambda*tmp1*R' = lambda*LSRD*R' // tmp2 = lambda*tmp1*R' = lambda*LSRD*R'
gemm(tmp1, this->RorL[this->m_indFact], tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle); gemm(tmp1, this->RorL[this->m_indFact], tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
// tmp3 = this->m_lambda*this->error*tmp2 // tmp3 = this->m_lambda*this->error*tmp2
gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle); gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle);
// grad_over_c = 1/this->c*L'*tmp3 // grad_over_c = 1/this->c*L'*tmp3
gemm(this->LorR, tmp3, this->grad_over_c,(FPP) 1.0/this->c, (FPP) 0.0,'T','N', this->blas_handle); gemm(this->LorR, tmp3, this->grad_over_c,(FPP) 1.0/this->c, (FPP) 0.0,this->TorH,'N', this->blas_handle);
} }
else else
{ {
// tmp2 = lambda*tmp1*R' = lambda*LSRD*R' // tmp2 = lambda*tmp1*R' = lambda*LSRD*R'
gemm(tmp1, this->LorR, tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', 'T', this->blas_handle); gemm(tmp1, this->LorR, tmp2, (FPP) this->m_lambda, (FPP) 0, 'N', this->TorH, this->blas_handle);
// tmp3 = this->m_lambda*this->error*tmp2 // tmp3 = this->m_lambda*this->error*tmp2
gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle); gemm(this->error, tmp2, tmp3, this->m_lambda, (FPP) 0.0, 'N', 'N', this->blas_handle);
// grad_over_c = 1/this->c*L'*tmp3 // grad_over_c = 1/this->c*L'*tmp3
gemm(this->RorL[this->m_indFact], tmp3, this->grad_over_c, (FPP) 1.0/this->c, (FPP) 0.0,'T','N', this->blas_handle); gemm(this->RorL[this->m_indFact], tmp3, this->grad_over_c, (FPP) 1.0/this->c, (FPP) 0.0,this->TorH,'N', this->blas_handle);
} }
} }
//TODO: avoid type checking by adding another template function (for complex and real types) or function pointer
if(typeid(this->data.getData()[0]) == typeid(complex<float>) || typeid(this->data.getData()[0]) == typeid(complex<double>))
{
this->LorR.conjugate();
this->RorL[this->m_indFact].conjugate();
}
this->isGradComputed = true; this->isGradComputed = true;
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
...@@ -165,7 +157,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda() ...@@ -165,7 +157,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_lambda()
// Xhat = LorR*D*LorR' // LorR equals the prod of all factors after their update iterations (in loop of next_step()) // Xhat = LorR*D*LorR' // LorR equals the prod of all factors after their update iterations (in loop of next_step())
MatDense<FPP,Cpu> tmp; MatDense<FPP,Cpu> tmp;
// tmp = D*LorR' // tmp = D*LorR'
gemm(this->D, this->LorR, tmp, (FPP) 1.0, (FPP) 0.0, 'N', 'T', this->blas_handle); gemm(this->D, this->LorR, tmp, (FPP) 1.0, (FPP) 0.0, 'N', this->TorH, this->blas_handle);
// LorR = LorR*tmp // LorR = LorR*tmp
gemm(this->LorR, tmp, D_grad_over_c, (FPP) 1.0, (FPP) 0.0, 'N', 'N', this->blas_handle); gemm(this->LorR, tmp, D_grad_over_c, (FPP) 1.0, (FPP) 0.0, 'N', 'N', this->blas_handle);
tmp = this->LorR; tmp = this->LorR;
...@@ -216,7 +208,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D_grad_over_c() ...@@ -216,7 +208,7 @@ void Palm4MSAFFT<FPP,DEVICE,FPP2>::compute_D_grad_over_c()
D_grad_over_c -= this->data; D_grad_over_c -= this->data;
//TODO: opt. by determining best order of product //TODO: opt. by determining best order of product
// tmp = LorR'*(LorR*D*LorR' - X) // tmp = LorR'*(LorR*D*LorR' - X)
gemm(this->LorR, D_grad_over_c, tmp, (FPP) 1., (FPP) 0., 'T', 'N', this->blas_handle); gemm(this->LorR, D_grad_over_c, tmp, (FPP) 1., (FPP) 0., this->TorH, 'N', this->blas_handle);
// D_grad_over_c = LorR'*(LorR*D*LorR' - X)*LorR // D_grad_over_c = LorR'*(LorR*D*LorR' - X)*LorR
gemm(tmp, this->LorR, D_grad_over_c, (FPP) 1., (FPP) 0., 'N', 'N', this->blas_handle); gemm(tmp, this->LorR, D_grad_over_c, (FPP) 1., (FPP) 0., 'N', 'N', this->blas_handle);
} }
......
...@@ -322,6 +322,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -322,6 +322,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
#ifdef __COMPILE_TIMERS__ #ifdef __COMPILE_TIMERS__
A.t_gemm.start(); A.t_gemm.start();
#endif #endif
#ifdef __GEMM_WITH_OPENBLAS__
if(typeA == 'H' || typeB == 'H')
handleError("linear_algebra", " gemm: Hermitian matrix is not yet handled with BLAS.")
#endif
faust_unsigned_int nbRowOpA,nbRowOpB,nbColOpA,nbColOpB; faust_unsigned_int nbRowOpA,nbRowOpB,nbColOpA,nbColOpB;
if ( ((&(C.mat)) == (&(A.mat))) || ((&(C.mat)) == (&(B.mat))) ) if ( ((&(C.mat)) == (&(A.mat))) || ((&(C.mat)) == (&(B.mat))) )
...@@ -330,7 +335,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -330,7 +335,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
handleError("linear_algebra", " gemm : C is the same object as A or B"); handleError("linear_algebra", " gemm : C is the same object as A or B");
} }
if (typeA == 'T') if (typeA == 'T' || typeA == 'H')
{ {
nbRowOpA = A.getNbCol(); nbRowOpA = A.getNbCol();
nbColOpA = A.getNbRow(); nbColOpA = A.getNbRow();
...@@ -341,7 +346,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -341,7 +346,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
} }
if (typeB == 'T') if (typeB == 'T' || typeB == 'H')
{ {
nbRowOpB = B.getNbCol(); nbRowOpB = B.getNbCol();
nbColOpB = B.getNbRow(); nbColOpB = B.getNbRow();
...@@ -404,6 +409,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -404,6 +409,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
C=B; C=B;
if(typeB == 'T') if(typeB == 'T')
C.transpose(); C.transpose();
else if(typeB == 'H')
{
C.conjugate(false);
C.transpose();
}
if(alpha!=FPP(1.0)) if(alpha!=FPP(1.0))
C*= alpha; C*= alpha;
C.isZeros = false; C.isZeros = false;
...@@ -418,6 +428,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -418,6 +428,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
C=A; C=A;
if(typeA == 'T') if(typeA == 'T')
C.transpose(); C.transpose();
else if(typeA == 'H')
{
C.conjugate(false);
C.transpose();
}
if(alpha!=FPP(1.0)) if(alpha!=FPP(1.0))
C*= alpha; C*= alpha;
C.isZeros = false; C.isZeros = false;
...@@ -435,14 +450,27 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -435,14 +450,27 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
{ {
if (typeB == 'N') if (typeB == 'N')
C.mat.noalias() = alpha * A.mat * B.mat; C.mat.noalias() = alpha * A.mat * B.mat;
else else if(typeB == 'T')
C.mat.noalias() = alpha * A.mat * B.mat.transpose(); C.mat.noalias() = alpha * A.mat * B.mat.transpose();
}else else // typeB == 'H' //TODO: check validity of typeA and typeB
C.mat.noalias() = alpha * A.mat * B.mat.transpose().conjugate();
}else if(typeA == 'T')
{ {
if (typeB == 'N') if (typeB == 'N')
C.mat.noalias() = alpha * A.mat.transpose() * B.mat; C.mat.noalias() = alpha * A.mat.transpose() * B.mat;
else else if(typeB == 'T')
C.mat.noalias() = alpha * A.mat.transpose() * B.mat.transpose(); C.mat.noalias() = alpha * A.mat.transpose() * B.mat.transpose();
else // typeB == 'H' //TODO: check validity of typeA and typeB
C.mat.noalias() = alpha * A.mat.transpose() * B.mat.transpose().conjugate();
}
else //if(typeA == 'H')
{
if (typeB == 'N')
C.mat.noalias() = alpha * A.mat.transpose().conjugate() * B.mat;
else if(typeB == 'T')
C.mat.noalias() = alpha * A.mat.transpose().conjugate() * B.mat.transpose();
else // typeB == 'H' //TODO: check validity of typeA and typeB
C.mat.noalias() = alpha * A.mat.transpose().conjugate() * B.mat.transpose().conjugate();
} }
#else #else
FPP beta = FPP(0.0); FPP beta = FPP(0.0);
...@@ -478,6 +506,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -478,6 +506,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
Faust::MatDense<FPP,Cpu> B_tmp(B); Faust::MatDense<FPP,Cpu> B_tmp(B);
if(typeB == 'T') if(typeB == 'T')
B_tmp.transpose(); B_tmp.transpose();
else if(typeB == 'H')
{
B_tmp.conjugate(false);
B_tmp.transpose();
}
if(alpha != FPP(1.0)) if(alpha != FPP(1.0))
B_tmp *= alpha; B_tmp *= alpha;
C += B_tmp; C += B_tmp;
...@@ -504,6 +537,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -504,6 +537,11 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
Faust::MatDense<FPP,Cpu> A_tmp(A); Faust::MatDense<FPP,Cpu> A_tmp(A);
if(typeA == 'T') if(typeA == 'T')
A_tmp.transpose(); A_tmp.transpose();
else if(typeA == 'H')
{
A_tmp.conjugate(false);
A_tmp.transpose();
}
if(alpha != FPP(1.0)) if(alpha != FPP(1.0))
A_tmp *= alpha; A_tmp *= alpha;
C += A_tmp; C += A_tmp;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment