Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 53a9aa39 authored by hhakim's avatar hhakim
Browse files

Update/Refactor MHTP algorithm in its own module.

- Enabling the use of updating_lambda, palm4msa_period MTHP parameters.
- Refactor update_fac(t) and update_lambda in separate functions of palm4msa2.hpp because they are used both by PALM4MSA and MHTP.
parent 7900acf5
Branches
Tags
No related merge requests found
...@@ -19,6 +19,36 @@ namespace Faust ...@@ -19,6 +19,36 @@ namespace Faust
MHTPParams(); MHTPParams();
std::string to_string() const; std::string to_string() const;
}; };
/**
* \brief This function performs the Multilinear Hard Tresholding Pursuit) for as number of iterations as defined in mhtp_params.
*
* Reference: https://hal.inria.fr/hal-03132013/document
*/
template<typename FPP, FDevice DEVICE>
void perform_MHTP(
const MHTPParams<FPP>& mhtp_params,
Faust::MatGeneric<FPP,DEVICE>* cur_fac,
int f_id,
const Faust::MatDense<FPP,DEVICE>& A,
const Faust::MatDense<FPP,DEVICE>& A_H,
Faust::TransformHelper<FPP,DEVICE>& S,
std::vector<TransformHelper<FPP,DEVICE>*> &pL,
std::vector<TransformHelper<FPP,DEVICE>*> &pR,
const bool is_verbose,
std::vector<Faust::ConstraintGeneric*> & constraints,
const int norm2_max_iter,
const Real<FPP>& norm2_threshold,
std::chrono::duration<double>& spectral_duration,
std::chrono::duration<double>& fgrad_duration,
const StoppingCriterion<Real<FPP>>& sc,
Real<FPP> &error,
const bool use_csr,
const bool packing_RL,
const int prod_mod,
Real<FPP> &c,
Real<FPP>& lambda);
}; };
#include "faust_MHTP.hpp" #include "faust_MHTP.hpp"
#endif #endif
#include "faust_palm4msa2020.h"
namespace Faust namespace Faust
{ {
template<typename FPP> template<typename FPP>
...@@ -15,9 +16,9 @@ namespace Faust ...@@ -15,9 +16,9 @@ namespace Faust
std::string str = "MHTPParams (START):"; std::string str = "MHTPParams (START):";
str += "\r\n"; str += "\r\n";
str += "StoppingCriterion:"; str += "StoppingCriterion:";
str += "\r\n ==="; str += "\r\n === \r\n";
str += sc_str; str += sc_str;
str += " === \r\n"; str += "\r\n === \r\n";
str += "constant_step_size: "; str += "constant_step_size: ";
str += std::to_string(constant_step_size); str += std::to_string(constant_step_size);
str += "\r\n"; str += "\r\n";
...@@ -33,4 +34,50 @@ namespace Faust ...@@ -33,4 +34,50 @@ namespace Faust
str += "MHTPParams END."; str += "MHTPParams END.";
return str; return str;
} }
template<typename FPP, FDevice DEVICE>
void perform_MHTP(
const MHTPParams<FPP>& mhtp_params,
Faust::MatGeneric<FPP,DEVICE>* cur_fac,
int f_id,
const Faust::MatDense<FPP,DEVICE>& A,
const Faust::MatDense<FPP,DEVICE>& A_H,
Faust::TransformHelper<FPP,DEVICE>& S,
std::vector<TransformHelper<FPP,DEVICE>*> &pL,
std::vector<TransformHelper<FPP,DEVICE>*> &pR,
const bool is_verbose,
std::vector<Faust::ConstraintGeneric*> & constraints,
const int norm2_max_iter,
const Real<FPP>& norm2_threshold,
std::chrono::duration<double>& norm2_duration,
std::chrono::duration<double>& fgrad_duration,
const StoppingCriterion<Real<FPP>>& sc,
Real<FPP> &error,
const bool use_csr,
const bool packing_RL,
const int prod_mod,
Real<FPP> &c,
Real<FPP>& lambda)
{
if(is_verbose)
std::cout << "Starting a MHTP pass ("<< mhtp_params.sc.get_crit() <<" iterations) for factor #" << f_id << std::endl;
int j = 0;
// set the factor to zero
cur_fac->setZeros();
while(mhtp_params.sc.do_continue(j)) // TODO: what about the error stop criterion?
{
update_fact(cur_fac, f_id, A, S, pL, pR,
is_verbose, constraints,
norm2_max_iter, norm2_threshold, norm2_duration, fgrad_duration,
mhtp_params.constant_step_size, mhtp_params.step_size,
sc, error, use_csr, packing_RL, prod_mod, c, lambda);
if(mhtp_params.updating_lambda)
update_lambda(S, A_H, lambda);
j++;
}
if(is_verbose)
std::cout << "The MHTP pass has ended" << std::endl;
}
} }
...@@ -29,6 +29,8 @@ ...@@ -29,6 +29,8 @@
if (nullptr != env_var) \ if (nullptr != env_var) \
ite_period = atoi(env_var); ite_period = atoi(env_var);
#define LIPSCHITZ_MULTIPLICATOR 1.001
namespace Faust namespace Faust
{ {
...@@ -107,6 +109,42 @@ namespace Faust ...@@ -107,6 +109,42 @@ namespace Faust
template<typename FPP, FDevice DEVICE> template<typename FPP, FDevice DEVICE>
Real<FPP> calc_rel_err(const TransformHelper<FPP,DEVICE>& S, const MatDense<FPP,DEVICE> &A, const FPP &lambda=1, const Real<FPP>* A_norm=nullptr); Real<FPP> calc_rel_err(const TransformHelper<FPP,DEVICE>& S, const MatDense<FPP,DEVICE> &A, const FPP &lambda=1, const Real<FPP>* A_norm=nullptr);
/**
* \brief This function performs the (scaling factor) lambda update of the PALM4MSA algorithm (palm4msa2).
*
* \param S: the Faust being refined by the PALM4MSA algorithm (palm4msa2).
* \param A_H: the transconjugate of the matrix A for which S is an approximate.
* \param lambda: the output of the lambda computed by the function.
*/
template<typename FPP, FDevice DEVICE>
void update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, const MatDense<FPP, DEVICE> A_H, Real<FPP>& lambda);
template<typename FPP, FDevice DEVICE>
void update_fact(
Faust::MatGeneric<FPP,DEVICE>* cur_fac,
int f_id,
const Faust::MatDense<FPP,DEVICE>& A,
Faust::TransformHelper<FPP,DEVICE>& S,
std::vector<TransformHelper<FPP,DEVICE>*> &pL,
std::vector<TransformHelper<FPP,DEVICE>*> &pR,
const bool is_verbose,
std::vector<Faust::ConstraintGeneric*> & constraints,
const int norm2_max_iter,
const Real<FPP>& norm2_threshold,
std::chrono::duration<double>& spectral_duration,
std::chrono::duration<double>& fgrad_duration,
const bool constant_step_size,
const Real<FPP> step_size,
const StoppingCriterion<Real<FPP>>& sc,
Real<FPP> &error,
const bool use_csr,
const bool packing_RL,
const int prod_mod,
Real<FPP> &c,
const Real<FPP>& lambda);
} }
#include "faust_palm4msa2020.hpp" #include "faust_palm4msa2020.hpp"
#include "faust_palm4msa2020_2.hpp" #include "faust_palm4msa2020_2.hpp"
......
...@@ -16,9 +16,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -16,9 +16,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
const bool on_gpu /*=false*/, const bool on_gpu /*=false*/,
const bool is_verbose/*=false*/, const int id/*=0*/) const bool is_verbose/*=false*/, const int id/*=0*/)
{ {
std::chrono::time_point<std::chrono::high_resolution_clock> spectral_stop, spectral_start; std::chrono::duration<double> norm2_duration = std::chrono::duration<double>::zero();
std::chrono::duration<double> spectral_duration = std::chrono::duration<double>::zero();
std::chrono::time_point<std::chrono::high_resolution_clock> fgrad_stop, fgrad_start;
std::chrono::duration<double> fgrad_duration = std::chrono::duration<double>::zero(); std::chrono::duration<double> fgrad_duration = std::chrono::duration<double>::zero();
int prod_mod = ORDER_ALL_BEST_MIXED; int prod_mod = ORDER_ALL_BEST_MIXED;
double norm1, norm2; double norm1, norm2;
...@@ -36,10 +34,10 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -36,10 +34,10 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
for(auto c: constraints) for(auto c: constraints)
dims.push_back(make_pair(c->get_rows(), c->get_cols())); dims.push_back(make_pair(c->get_rows(), c->get_cols()));
//TODO: make it possible to receive a MatSparse A //TODO: make it possible to receive a MatSparse A
if(is_verbose) if(is_verbose && mhtp_params.used)
{ {
std::cout << "use_MHTP: " << mhtp_params.used << std::endl; std::cout << mhtp_params.constant_step_size << std::endl;
std::cout<<"MHTP stop crit.: "<< std::endl << mhtp_params.sc.to_string() <<std::endl; std::cout << mhtp_params.to_string() << std::endl;
} }
Faust::MatDense<FPP,DEVICE> A_H = A; Faust::MatDense<FPP,DEVICE> A_H = A;
A_H.adjoint(); A_H.adjoint();
...@@ -49,7 +47,6 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -49,7 +47,6 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
std::function<void()> init_ite, next_fid; std::function<void()> init_ite, next_fid;
std::function<bool()> updating_facs; std::function<bool()> updating_facs;
std::function<bool()> is_last_fac_updated; std::function<bool()> is_last_fac_updated;
std::function<void(Faust::MatGeneric<FPP,DEVICE>*, int f_id)> update_fac;
// packed Fausts corresponding to each factor // packed Fausts corresponding to each factor
std::vector<TransformHelper<FPP,DEVICE>*> pL, pR; std::vector<TransformHelper<FPP,DEVICE>*> pL, pR;
pL.resize(nfacts);// pL[i] is the Faust for all factors to the left of the factor *(S.begin()+i) pL.resize(nfacts);// pL[i] is the Faust for all factors to the left of the factor *(S.begin()+i)
...@@ -126,72 +123,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -126,72 +123,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
updating_facs = [&f_id, &nfacts]() {return f_id < nfacts;}; updating_facs = [&f_id, &nfacts]() {return f_id < nfacts;};
is_last_fac_updated = [&f_id, &nfacts]() {return f_id == nfacts-1;}; is_last_fac_updated = [&f_id, &nfacts]() {return f_id == nfacts-1;};
} }
update_fac = [&A, &D, &spD, &constant_step_size, &is_verbose, &spectral_start, &spectral_stop, &spectral_duration,
&fgrad_start, &fgrad_stop, &fgrad_duration,
&c, &lipschitz_multiplicator, &sc, &error, &prod_mod, &packing_RL, &constraints,
&pR, &pL, &norm2_max_iter, &norm2_threshold, &norm2_flag, &lambda, &S, &scur_fac, &dcur_fac,
&use_csr]
(Faust::MatGeneric<FPP,DEVICE>* cur_fac, int f_id)
{
Real<FPP> nR=1,nL=1;
if(! constant_step_size)
{
if(is_verbose)
spectral_start = std::chrono::high_resolution_clock::now();
if(pR[f_id]->size() > 0)
nR = pR[f_id]->spectralNorm(norm2_max_iter, norm2_threshold, norm2_flag);
if(pL[f_id]->size() > 0)
nL = pL[f_id]->spectralNorm(norm2_max_iter, norm2_threshold, norm2_flag);
if(is_verbose)
{
spectral_stop = std::chrono::high_resolution_clock::now();
spectral_duration += spectral_stop-spectral_start;
}
c = lipschitz_multiplicator*lambda*lambda*nR*nR*nL*nL;
}
if(S.is_fact_sparse(f_id))
{
scur_fac = dynamic_cast<Faust::MatSparse<FPP,DEVICE>*>(cur_fac);
D = *scur_fac;
dcur_fac = nullptr;
}
else
{
dcur_fac = dynamic_cast<Faust::MatDense<FPP,DEVICE>*>(cur_fac); // TOFIX: possible it is not Dense... but MatDiag (sanity check on function start)
D = *dcur_fac;
scur_fac = nullptr;
}
if(is_verbose)
fgrad_start = std::chrono::high_resolution_clock::now();
if(typeid(D) != typeid(Faust::MatDense<FPP,Cpu>))
// compute_n_apply_grad2 is not yet supported with GPU2
compute_n_apply_grad1(f_id, A, S, pL, pR, lambda, c, D, sc, error, prod_mod, packing_RL);
else
compute_n_apply_grad2(f_id, A, S, pL, pR, lambda, c, D, sc, error, prod_mod, packing_RL);
if(is_verbose)
{
fgrad_stop = std::chrono::high_resolution_clock::now();
fgrad_duration += fgrad_stop-fgrad_start;
}
// really update now
constraints[f_id]->project<FPP,DEVICE,Real<FPP>>(D);
if(use_csr && dcur_fac != nullptr || !use_csr && scur_fac != nullptr)
throw std::runtime_error("Current factor is inconsistent with use_csr.");
if(use_csr)
{
spD = D;
S.update(spD, f_id); // update is at higher level than a simple assignment
}
else
{
S.update(D, f_id);
}
};
while(sc.do_continue(i, error)) while(sc.do_continue(i, error))
{ {
// std::cout << "i: " << i << std::endl; // std::cout << "i: " << i << std::endl;
...@@ -202,58 +134,30 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -202,58 +134,30 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
{ {
// std::cout << "#f_id: " << f_id << std::endl; // std::cout << "#f_id: " << f_id << std::endl;
cur_fac = S.get_gen_fact_nonconst(f_id); cur_fac = S.get_gen_fact_nonconst(f_id);
if(i%1000 == 0 && mhtp_params.used) if(mhtp_params.used && i%mhtp_params.palm4msa_period == 0)
{ {
if(is_verbose) perform_MHTP(mhtp_params, cur_fac, f_id, A, A_H, S, pL, pR,
std::cout << "MHTP" << std::endl; is_verbose, constraints, norm2_max_iter, norm2_threshold,
j = 0; norm2_duration,
// set the factor to zero fgrad_duration,
cur_fac->setZeros(); sc, error, use_csr, packing_RL, prod_mod, c, lambda);
while(mhtp_params.sc.do_continue(j)) // TODO: what about the error stop criterion?
{
if(mhtp_params.constant_step_size)
{
//TODO: add arguments to update_fac in order to avoid shunting PALM4MSA parameters with MHTP's
constant_step_size = true;
step_size = mhtp_params.step_size;
c = 1 / step_size;
}
update_fac(cur_fac, f_id);
j++;
Faust::MatDense<FPP,DEVICE> A_H_S = S.multiply(A_H);
Real<FPP> trr = std::real(A_H_S.trace());
Real<FPP> n = S.normFro();
if(std::numeric_limits<Real<FPP>>::epsilon() >= n)
throw std::runtime_error("Faust Frobenius norm is zero, can't compute lambda.");
lambda = trr/(n*n); //TODO: raise exception if n == 0
}
if(mhtp_params.constant_step_size) //TODO: cf above
constant_step_size = false;
if(is_verbose)
std::cout << "end MHTP" << std::endl;
} }
else else
update_fac(cur_fac, f_id); update_fact(cur_fac, f_id, A, S, pL, pR,
is_verbose, constraints, norm2_max_iter, norm2_threshold,
norm2_duration,
fgrad_duration,
constant_step_size, step_size,
sc, error, use_csr, packing_RL, prod_mod, c, lambda);
next_fid(); // f_id updated to iteration factor index (pL or pR too) next_fid(); // f_id updated to iteration factor index (pL or pR too)
} }
//update lambda //update lambda
//TODO: variable decl in parent scope update_lambda(S, A_H, lambda);
Faust::MatDense<FPP,DEVICE> A_H_S = S.multiply(A_H);
// auto last_Sfac_vec = { *(S.begin()+nfacts-1), dynamic_cast<Faust::MatGeneric<FPP,DEVICE>*>(&A_H)};
// Faust::TransformHelper<FPP,DEVICE> A_H_S_(*pL[nfacts-1], last_Sfac_vec);
// A_H_S_.disable_dtor();
// Faust::MatDense<FPP,DEVICE> A_H_S = A_H_S_.get_product();
Real<FPP> trr = std::real(A_H_S.trace());
Real<FPP> n = S.normFro();
if(std::numeric_limits<Real<FPP>>::epsilon() >= n)
throw std::runtime_error("Faust Frobenius norm is zero, can't compute lambda.");
lambda = trr/(n*n); //TODO: raise exception if n == 0
// std::cout << "debug lambda: " << lambda << std::endl;
if(is_verbose) if(is_verbose)
{ {
set_calc_err_ite_period(); //macro setting the variable ite_period set_calc_err_ite_period(); //macro setting the variable ite_period
if(! (i%ite_period)) if((! (i%ite_period)) && ite_period > 0)
{ {
std::cout << "PALM4MSA2020 iteration: " << i; std::cout << "PALM4MSA2020 iteration: " << i;
auto err = calc_rel_err(S, A, lambda); auto err = calc_rel_err(S, A, lambda);
...@@ -274,7 +178,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -274,7 +178,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
} }
if(is_verbose) if(is_verbose)
{ {
std::cout << "palm4msa spectral time=" << spectral_duration.count() << std::endl; std::cout << "palm4msa spectral time=" << norm2_duration.count() << std::endl;
std::cout << "palm4msa fgrad time=" << fgrad_duration.count() << std::endl; std::cout << "palm4msa fgrad time=" << fgrad_duration.count() << std::endl;
} }
} }
...@@ -400,4 +304,109 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI ...@@ -400,4 +304,109 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
} }
} }
template<typename FPP, FDevice DEVICE>
void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, const MatDense<FPP, DEVICE> A_H, Real<FPP>& lambda)
{
Faust::MatDense<FPP,DEVICE> A_H_S = S.multiply(A_H);
Real<FPP> trr = std::real(A_H_S.trace());
Real<FPP> n = S.normFro();
if(std::numeric_limits<Real<FPP>>::epsilon() >= n)
throw std::runtime_error("Error in update_lambda: Faust Frobenius norm is zero, can't compute lambda.");
lambda = trr/(n*n);
}
template<typename FPP, FDevice DEVICE>
void Faust::update_fact(
Faust::MatGeneric<FPP,DEVICE>* cur_fac,
int f_id,
const Faust::MatDense<FPP,DEVICE>& A,
Faust::TransformHelper<FPP,DEVICE>& S,
std::vector<TransformHelper<FPP,DEVICE>*> &pL,
std::vector<TransformHelper<FPP,DEVICE>*> &pR,
const bool is_verbose,
std::vector<Faust::ConstraintGeneric*> & constraints,
const int norm2_max_iter,
const Real<FPP>& norm2_threshold,
std::chrono::duration<double>& norm2_duration,
std::chrono::duration<double>& fgrad_duration,
const bool constant_step_size,
const Real<FPP> step_size,
const StoppingCriterion<Real<FPP>>& sc,
Real<FPP> &error,
const bool use_csr,
const bool packing_RL,
const int prod_mod,
Real<FPP> &c,
const Real<FPP>& lambda)
{
int norm2_flag;
std::chrono::time_point<std::chrono::high_resolution_clock> spectral_stop, spectral_start;
std::chrono::time_point<std::chrono::high_resolution_clock> fgrad_stop, fgrad_start;
Faust::MatSparse<FPP,DEVICE>* scur_fac = nullptr;
Faust::MatDense<FPP,DEVICE>* dcur_fac = nullptr;
Faust::MatDense<FPP,DEVICE> D;
Faust::MatSparse<FPP,DEVICE> spD;
Real<FPP> nR=1,nL=1;
if(constant_step_size)
c = 1 / step_size;
else
{
if(is_verbose)
spectral_start = std::chrono::high_resolution_clock::now();
if(pR[f_id]->size() > 0)
nR = pR[f_id]->spectralNorm(norm2_max_iter, norm2_threshold, norm2_flag);
if(pL[f_id]->size() > 0)
nL = pL[f_id]->spectralNorm(norm2_max_iter, norm2_threshold, norm2_flag);
if(is_verbose)
{
spectral_stop = std::chrono::high_resolution_clock::now();
norm2_duration += spectral_stop-spectral_start;
}
c = LIPSCHITZ_MULTIPLICATOR*lambda*lambda*nR*nR*nL*nL;
}
if(S.is_fact_sparse(f_id))
{
scur_fac = dynamic_cast<Faust::MatSparse<FPP,DEVICE>*>(cur_fac);
D = *scur_fac;
dcur_fac = nullptr;
}
else
{
dcur_fac = dynamic_cast<Faust::MatDense<FPP,DEVICE>*>(cur_fac); // TOFIX: possible it is not Dense... but MatDiag (sanity check on function start)
D = *dcur_fac;
scur_fac = nullptr;
}
if(is_verbose)
fgrad_start = std::chrono::high_resolution_clock::now();
if(typeid(D) != typeid(Faust::MatDense<FPP,Cpu>))
// compute_n_apply_grad2 is not yet supported with GPU2
compute_n_apply_grad1(f_id, A, S, pL, pR, lambda, c, D, sc, error, prod_mod, packing_RL);
else
compute_n_apply_grad2(f_id, A, S, pL, pR, lambda, c, D, sc, error, prod_mod, packing_RL);
if(is_verbose)
{
fgrad_stop = std::chrono::high_resolution_clock::now();
fgrad_duration += fgrad_stop-fgrad_start;
}
// really update now
constraints[f_id]->project<FPP,DEVICE,Real<FPP>>(D);
if(use_csr && dcur_fac != nullptr || !use_csr && scur_fac != nullptr)
throw std::runtime_error("Current factor is inconsistent with use_csr.");
if(use_csr)
{
spD = D;
S.update(spD, f_id); // update is at higher level than a simple assignment
}
else
{
S.update(D, f_id);
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment