Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ed06bfc5 authored by hhakim's avatar hhakim
Browse files

Enable StoppingCriterion epsilon error change to 1e-6 by default in PALM4MSA 2020 implementation.

Waiting for addition of the epsilon on error as a configurable parameter.
This code undoes f41949e6 for refactoring (in StoppingCriterion) and optimization reasons (avoiding recomputing the error several times in the same iteration).
parent 16956364
Branches
Tags
No related merge requests found
......@@ -49,7 +49,7 @@ namespace Faust
/** lambda output, intialized from outside */
Real<FPP>& lambda,
//const unsigned int nites,
const StoppingCriterion<Real<FPP>>& sc,
StoppingCriterion<Real<FPP>>& sc,
const bool is_update_way_R2L=false,
const FactorsFormat factors_format=AllDynamic,
const bool packing_RL=true,
......@@ -95,8 +95,6 @@ namespace Faust
template<typename FPP, FDevice DEVICE>
Real<FPP> calc_rel_err(const TransformHelper<FPP,DEVICE>& S, const MatDense<FPP,DEVICE> &A, const Real<FPP> &lambda=1, const Real<FPP>* A_norm=nullptr);
template<typename FPP, FDevice DEVICE> Real<FPP> compute_rel_change(const TransformHelper<FPP,DEVICE>& previousS, const Real<FPP>& previouslambda , const TransformHelper<FPP,DEVICE>& currentS, const Real<FPP>& currentlambda);
/**
* \brief This function performs the (scaling factor) lambda update of the PALM4MSA algorithm (palm4msa2).
*
......
......@@ -43,23 +43,13 @@ Real<FPP> Faust::calc_rel_err(const TransformHelper<FPP,DEVICE>& S, const MatDen
return err.norm() / *A_norm;
}
template<typename FPP, FDevice DEVICE> Real<FPP> Faust::compute_rel_change(const TransformHelper<FPP,DEVICE>& previousS, const Real<FPP>& previouslambda, const TransformHelper<FPP,DEVICE>& currentS, const Real<FPP>& currentlambda)
{
MatDense<FPP, DEVICE> m1 = const_cast<TransformHelper<FPP, DEVICE>&>(previousS).get_product();
m1 *= FPP(previouslambda);
MatDense<FPP, DEVICE> dm = const_cast<TransformHelper<FPP, DEVICE>&>(currentS).get_product();
dm *= FPP(currentlambda);
dm -= m1;
return dm.norm() / m1.norm();
}
template <typename FPP, FDevice DEVICE>
void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
std::vector<Faust::ConstraintGeneric*> & constraints,
Faust::TransformHelper<FPP,DEVICE>& S,
Real<FPP>& lambda, //TODO: FPP lambda ? is it useful to have a complex lamdba ?
//const unsigned int nites,
const StoppingCriterion<Real<FPP>>& sc,
StoppingCriterion<Real<FPP>>& sc,
const bool is_update_way_R2L,
const FactorsFormat factors_format,
const bool packing_RL,
......@@ -245,9 +235,8 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
is_last_fac_updated = [&f_id, &nfacts]() {return f_id == nfacts-1;};
}
// to compute relative change
Faust::TransformHelper<FPP,DEVICE> previousS;
Real<FPP> previouslambda;
// to stop on small error change between two iterations
sc.setCriterionEpsErr(1e-6);
while(sc.do_continue(i, error))
{
......@@ -255,8 +244,6 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
// std::cout << "nfacts:" << nfacts << std::endl;
init_ite();
if(i == 0)
previousS = S;
while(updating_facs())
{
// std::cout << "#f_id: " << f_id << std::endl;
......@@ -294,20 +281,6 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
std::cout << " lambda=" << lambda << std::endl;
}
}
// check if solution changes, if not stop while
if(i == 0){
previousS = S;
previouslambda = lambda;
}else{
auto rel_change = compute_rel_change(previousS, previouslambda, S, lambda);
std::cout << "relative change: " << rel_change << std::endl;
if(rel_change < 1e-6){
cout << "relative change is < 1e-6, stop iterations at " << i << std::endl;
break;
}
previousS = S;
previouslambda = lambda;
}
i++;
}
S.update_total_nnz();
......@@ -341,7 +314,7 @@ void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVI
// _LSR.get_product(tmp);
tmp *= FPP(lambda);
tmp -= A;
if(sc.isCriterionErr())
if(sc.isCriterionErr() || sc.isCriterionEpsErr())
error = tmp.norm();
FPP alpha_R = 1, alpha_L = 1, beta_R = 0, beta_L = 0; //decl in parent scope
auto pR_sz = pR[f_id]->size();
......@@ -430,7 +403,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
// compute error = m_lambda*L*S*R-data
facts = { _L, &D, _R };
mul_3_facts(facts, tmp, (FPP) lambda, (FPP) -1.0);
if(sc.isCriterionErr())
if(sc.isCriterionErr() || sc.isCriterionEpsErr())
error = tmp.norm();
// compute m_lambda/c * L'*error*R'
facts = { _L, &tmp, _R };
......@@ -441,7 +414,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
// compute error = m_lambda*L*S*R-data
facts = { &D, _R };
mul_3_facts(facts, tmp, (FPP) lambda, (FPP) -1.0);
if(sc.isCriterionErr())
if(sc.isCriterionErr() || sc.isCriterionEpsErr())
error = tmp.norm();
// compute m_lambda/c * L'*error*R'
facts = { &tmp, _R };
......@@ -452,7 +425,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
// compute error = m_lambda*L*S*R-data
facts = { _L, &D};
mul_3_facts(facts, tmp, (FPP) lambda, (FPP) -1.0);
if(sc.isCriterionErr())
if(sc.isCriterionErr() || sc.isCriterionEpsErr())
error = tmp.norm();
// compute m_lambda/c * L'*error*R'
facts = { _L, &tmp};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment