Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 8038559a authored by hhakim's avatar hhakim
Browse files

Optimize update_lambda computing in C++ PALM4MSA 2020 by decreasing the number...

Optimize update_lambda computing in C++ PALM4MSA 2020 by decreasing the number of matrix products needed to compute S full array.
parent f02508e5
Branches
Tags
No related merge requests found
......@@ -73,7 +73,7 @@ namespace Faust
mhtp_params.constant_step_size, mhtp_params.step_size,
sc, error, factors_format, prod_mod, c, lambda);
if(mhtp_params.updating_lambda)
update_lambda(S, A_H, lambda);
update_lambda(S, pL, pR, A_H, lambda);
j++;
}
if(is_verbose)
......
......@@ -119,7 +119,7 @@ namespace Faust
* \param lambda: the output of the lambda computed by the function.
*/
template<typename FPP, FDevice DEVICE>
void update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, const MatDense<FPP, DEVICE> A_H, Real<FPP>& lambda);
void update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda);
template<typename FPP, FDevice DEVICE>
void update_fact(
......
......@@ -160,7 +160,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
next_fid(); // f_id updated to iteration factor index (pL or pR too)
}
//update lambda
update_lambda(S, A_H, lambda);
update_lambda(S, pL, pR, A_H, lambda);
if(is_verbose)
{
set_calc_err_ite_period(); //macro setting the variable ite_period
......@@ -319,14 +319,41 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
}
/**
 * Recomputes the scalar lambda from the current Faust S and the matrix A_H as
 * real(trace(product of A_H and S)) / (norm of S)^2.
 *
 * NOTE(review): this span is a rendered diff hunk without +/- markers. It appears to
 * interleave the previous implementation (first signature, then the S.multiply/normFro
 * statements) with the new one (second signature, then everything from the second
 * A_H_S declaration on) -- confirm against the repository before compiling it as-is.
 *
 * \param S: the Faust whose full product and norm are used.
 * \param pL: (new signature only) vector of Faust pointers; only pL[S.size()-1] is read here.
 * \param pR: (new signature only) vector of Faust pointers; only pR[0] is read here.
 * \param A_H: the dense matrix multiplied with S (presumably the conjugate transpose of the
 *             matrix being factorized -- TODO confirm against the caller).
 * \param lambda: the output of the lambda computed by the function.
 *
 * Throws std::runtime_error when the norm of S is not greater than the machine epsilon of
 * Real<FPP>, and (new version) std::logic_error when the reuse path is selected while
 * pR[0] or pL[n-1] is null.
 */
template<typename FPP, FDevice DEVICE>
void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, const MatDense<FPP, DEVICE> A_H, Real<FPP>& lambda)
void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda)
{
// --- presumably the previous body: multiplies through the Faust S, then uses S.normFro() ---
Faust::MatDense<FPP,DEVICE> A_H_S = S.multiply(A_H);
Real<FPP> trr = std::real(A_H_S.trace());
Real<FPP> n = S.normFro();
// guard the division below against a (near-)zero norm
if(std::numeric_limits<Real<FPP>>::epsilon() >= n)
throw std::runtime_error("Error in update_lambda: Faust Frobenius norm is zero, can't compute lambda.");
lambda = trr/(n*n);
// --- presumably the new body: builds the full matrix of S once and reuses it for both the trace and the norm ---
Faust::MatDense<FPP,DEVICE> A_H_S;
MatDense<FPP, DEVICE> S_mat;
FPP tr; // trace of A_H_S
Real<FPP> nS; // Frobenius norm of S
// number of factors of S (NOTE(review): n == 0 would make the n-1 indices below wrap around -- confirm callers never pass an empty S)
auto n = S.size();
// true when pR[0] is null or has size() == 1, and pL[n-1] is null or has size() == 1
bool packing_RL = (pR[0] == nullptr || pR[0]->size() == 1) && (pL[n-1] == nullptr || pL[n-1]->size() == 1);
// compute the full matrix of S into S_mat
if(packing_RL)
{
// packing_RL also holds when a pointer is null; that case is rejected here before any dereference
if(pR[0] == nullptr || pL[n-1] == nullptr)
throw std::logic_error("update_lambda: pR and pL weren't properly initialized.");
// optimize the S product by re-using R[0] or L[n-1] (S == S[0]*R[0] or S == L[n-1]*S[n-1])
// the branch is chosen by comparing the two dimension products below (presumably a cost estimate of each final product)
if(S.get_gen_fact(0)->getNbCol() * pR[0]->getNbRow() < pL[n-1]->getNbCol() * S.get_gen_fact(n-1)->getNbRow())
{
// S_mat = product of a Faust built from the first factor of S and *pR[0]
auto S0 = { S.get_gen_fact(0) };
TransformHelper<FPP, DEVICE> _S(S0, *pR[0]);
_S.get_product(S_mat);
}
else
{
// S_mat = product of a Faust built from *pL[n-1] and the last factor of S
auto Sn = { S.get_gen_fact(n-1) };
TransformHelper<FPP, DEVICE> _S(*pL[n-1], Sn);
_S.get_product(S_mat);
}
}
else
// no reuse possible: compute the product of S directly
S.get_product(S_mat);
// presumably A_H_S = 1.0 * A_H * S_mat + 0.0 * A_H_S, neither operand transposed ('N', 'N') -- confirm against gemm's documentation
gemm(A_H, S_mat, A_H_S, (FPP) 1.0, (FPP) 0.0, 'N', 'N');
tr = A_H_S.trace();
nS = S_mat.norm();
// guard the division below against a (near-)zero norm
if(std::numeric_limits<Real<FPP>>::epsilon() >= nS)
throw std::runtime_error("Error in update_lambda: S Frobenius norm is zero, can't compute lambda.");
// only the real part of the trace is kept, so lambda is real even when FPP is complex
lambda = std::real(tr)/(nS*nS);
}
template<typename FPP, FDevice DEVICE>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment