Mentions légales du service

Skip to content
Snippets Groups Projects
Commit a7917ace authored by hhakim's avatar hhakim
Browse files

Optimize Faust::GivensFGFT::update_L().

- Big speedup (now faster than the original MATLAB script).
- update_L() becomes virtual because GivensFGFTParallel needs to keep the old definition (it has several pivots per Givens matrix, so it can't use the same optimization as GivensFGFT).

Unrelated minor change:
- Fix missing trailing ; in faust_linear_algebra.hpp.

[skip ci]
parent 97c9d8b3
No related branches found
No related tags found
No related merge requests found
...@@ -153,7 +153,7 @@ namespace Faust { ...@@ -153,7 +153,7 @@ namespace Faust {
* *
* Updates L after Givens factor update for the next iteration. * Updates L after Givens factor update for the next iteration.
*/ */
void update_L(); virtual void update_L();
/** /**
* \brief Algo. step 2.5. * \brief Algo. step 2.5.
......
...@@ -164,6 +164,8 @@ void GivensFGFT<FPP,DEVICE,FPP2>::update_fact() ...@@ -164,6 +164,8 @@ void GivensFGFT<FPP,DEVICE,FPP2>::update_fact()
// facts{j} = S; // facts{j} = S;
// //
int n = Lap.getNbRow(); int n = Lap.getNbRow();
FPP2 c = cos(theta);
FPP2 s = sin(theta);
// forget previous rotation coeffs // forget previous rotation coeffs
// and keep identity part (n first coeffs) // and keep identity part (n first coeffs)
fact_mod_row_ids.resize(n); fact_mod_row_ids.resize(n);
...@@ -173,19 +175,19 @@ void GivensFGFT<FPP,DEVICE,FPP2>::update_fact() ...@@ -173,19 +175,19 @@ void GivensFGFT<FPP,DEVICE,FPP2>::update_fact()
// 1st one // 1st one
fact_mod_row_ids.push_back(p); fact_mod_row_ids.push_back(p);
fact_mod_col_ids.push_back(p); fact_mod_col_ids.push_back(p);
fact_mod_values.push_back(cos(theta)); fact_mod_values.push_back(c);
// 2nd // 2nd
fact_mod_row_ids.push_back(p); fact_mod_row_ids.push_back(p);
fact_mod_col_ids.push_back(q); fact_mod_col_ids.push_back(q);
fact_mod_values.push_back(-sin(theta)); fact_mod_values.push_back(-s);
// 3rd // 3rd
fact_mod_row_ids.push_back(q); fact_mod_row_ids.push_back(q);
fact_mod_col_ids.push_back(p); fact_mod_col_ids.push_back(p);
fact_mod_values.push_back(sin(theta)); fact_mod_values.push_back(s);
// 4th // 4th
fact_mod_row_ids.push_back(q); fact_mod_row_ids.push_back(q);
fact_mod_col_ids.push_back(q); fact_mod_col_ids.push_back(q);
fact_mod_values.push_back(cos(theta)); fact_mod_values.push_back(c);
facts[ite] = MatSparse<FPP,DEVICE>(fact_mod_row_ids, fact_mod_col_ids, fact_mod_values, n, n); facts[ite] = MatSparse<FPP,DEVICE>(fact_mod_row_ids, fact_mod_col_ids, fact_mod_values, n, n);
#ifdef DEBUG_GIVENS #ifdef DEBUG_GIVENS
cout << "GivensFGFT::update_fact() ite: " << ite << " fact norm: " << facts[ite].norm() << endl; cout << "GivensFGFT::update_fact() ite: " << ite << " fact norm: " << facts[ite].norm() << endl;
...@@ -200,8 +202,52 @@ void GivensFGFT<FPP,DEVICE,FPP2>::update_L() ...@@ -200,8 +202,52 @@ void GivensFGFT<FPP,DEVICE,FPP2>::update_L()
#ifdef DEBUG_GIVENS #ifdef DEBUG_GIVENS
cout << "L(p,q) before update_L():" << L(p,q) << endl; cout << "L(p,q) before update_L():" << L(p,q) << endl;
#endif #endif
#define OPT_UPDATE_L
#ifndef OPT_UPDATE_L
facts[ite].multiply(L, 'T'); facts[ite].multiply(L, 'T');
L.multiplyRight(MatDense<FPP,Cpu>(facts[ite])); L.multiplyRight(facts[ite]);
#else
Vect<FPP,DEVICE> L_vec_p = L.get_row(p), L_vec_q = L.get_row(q);
Vect<FPP,DEVICE> tmp, tmp2;
FPP2 c = *(fact_mod_values.end()-1); // cos(theta)
FPP2 s = *(fact_mod_values.end()-2); // sin(theta)
// Writes the entries of `vec` into row `rowi` of L, one element per column.
// Element i goes to L.getData()[L.getNbRow()*i + rowi], i.e. consecutive row
// entries are L.getNbRow() elements apart in the buffer (which looks like
// column-major storage -- TODO confirm against MatDense).
// The macro previously ignored `vec` and always read the local `tmp`; it now
// reads the argument it is given. Both arguments are parenthesized so that
// expression arguments expand safely.
#define copy_vec2Lrow(vec,rowi) \
for(int i=0;i<L.getNbCol();i++) L.getData()[L.getNbRow()*i+(rowi)] = (vec)[i]
/*========== L = S'*L */
// L(p,:) = c*L(p,:) + s*L(q,:)
tmp = L_vec_p;
tmp *= c;
tmp2 = L_vec_q;
tmp2 *= s;
tmp += tmp2;
copy_vec2Lrow(tmp,p);
// L(q,:) = -s*L(p,:) + c*L(q,:)
tmp = L_vec_p;
tmp *= -s;
tmp2 = L_vec_q;
tmp2 *= c;
tmp += tmp2;
copy_vec2Lrow(tmp, q);
/*========== L *= S */
L_vec_p = L.get_col(p), L_vec_q = L.get_col(q);
// L(:,p) = c*L(:,p) + s*L(:,q)
tmp = L_vec_p;
tmp *= c;
tmp2 = L_vec_q;
tmp2 *= s;
tmp += tmp2;
memcpy(L.getData()+L.getNbRow()*p, tmp.getData(), sizeof(FPP)*L.getNbRow());
// L(:,q) = -s*L(:,p) + c*L(:,q)
tmp = L_vec_p;
tmp *= -s;
tmp2 = L_vec_q;
tmp2 *= c;
tmp += tmp2;
memcpy(L.getData()+L.getNbRow()*q, tmp.getData(), sizeof(FPP)*L.getNbRow());
#endif
#ifdef DEBUG_GIVENS #ifdef DEBUG_GIVENS
cout << "L(p,q) after update_L():" << L(p,q) << endl; cout << "L(p,q) after update_L():" << L(p,q) << endl;
#endif #endif
......
...@@ -56,6 +56,7 @@ namespace Faust { ...@@ -56,6 +56,7 @@ namespace Faust {
* Computes the coefficients of the last selected rotation matrix to be put later in current iteration factor. * Computes the coefficients of the last selected rotation matrix to be put later in current iteration factor.
*/ */
void update_fact(); void update_fact();
void update_L();
/** /**
* Constructs the current factor after computing all the coefficients (of rotation matrices) in temporary buffers (see update_fact()). * Constructs the current factor after computing all the coefficients (of rotation matrices) in temporary buffers (see update_fact()).
*/ */
......
...@@ -146,6 +146,16 @@ void GivensFGFTParallel<FPP,DEVICE,FPP2>::update_fact() ...@@ -146,6 +146,16 @@ void GivensFGFTParallel<FPP,DEVICE,FPP2>::update_fact()
this->fact_mod_values.push_back(cos(this->theta)); this->fact_mod_values.push_back(cos(this->theta));
} }
/**
 * Updates L after the Givens factor update, for the next iteration:
 * L = S'*L*S, where S is the current iteration's factor (facts[ite]).
 *
 * This override keeps the generic two-product form: the factor is applied
 * to L through multiply() with the 'T' flag (presumably the transposed
 * left product -- confirm against MatSparse::multiply), then L is multiplied
 * on the right by the same factor.
 */
template<typename FPP, Device DEVICE, typename FPP2>
void GivensFGFTParallel<FPP,DEVICE,FPP2>::update_L()
{
	// L = S'*L*S
#ifdef DEBUG_GIVENS
	// Members inherited from the dependent base class must be reached through
	// this->; unqualified L, p and q are not found during template name lookup.
	cout << "L(p,q) before update_L():" << this->L(this->p,this->q) << endl;
#endif
	this->facts[this->ite].multiply(this->L, 'T');
	this->L.multiplyRight(this->facts[this->ite]);
}
template<typename FPP, Device DEVICE, typename FPP2> template<typename FPP, Device DEVICE, typename FPP2>
void GivensFGFTParallel<FPP,DEVICE,FPP2>::finish_fact() void GivensFGFTParallel<FPP,DEVICE,FPP2>::finish_fact()
......
...@@ -325,7 +325,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F ...@@ -325,7 +325,7 @@ void Faust::gemm_core(const Faust::MatDense<FPP,Cpu> & A,const Faust::MatDense<F
#ifdef __GEMM_WITH_OPENBLAS__ #ifdef __GEMM_WITH_OPENBLAS__
if(typeA == 'H' || typeB == 'H') if(typeA == 'H' || typeB == 'H')
handleError("linear_algebra", " gemm: Hermitian matrix is not yet handled with BLAS.") handleError("linear_algebra", " gemm: Hermitian matrix is not yet handled with BLAS.");
#endif #endif
faust_unsigned_int nbRowOpA,nbRowOpB,nbColOpA,nbColOpB; faust_unsigned_int nbRowOpA,nbRowOpB,nbColOpA,nbColOpB;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment