Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 4b6f8c08 authored by hhakim's avatar hhakim
Browse files

Add hidden options (through environment variables) in PALM4MSA 2020 (no_lambda_error, use_grad1).

parent 92349594
No related branches found
No related tags found
No related merge requests found
...@@ -120,7 +120,7 @@ namespace Faust ...@@ -120,7 +120,7 @@ namespace Faust
* \param lambda: the output of the lambda computed by the function. * \param lambda: the output of the lambda computed by the function.
*/ */
template<typename FPP, FDevice DEVICE> template<typename FPP, FDevice DEVICE>
void update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda); void update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda, bool no_lambda_error=false);
template<typename FPP, FDevice DEVICE> template<typename FPP, FDevice DEVICE>
void update_fact( void update_fact(
...@@ -144,7 +144,9 @@ namespace Faust ...@@ -144,7 +144,9 @@ namespace Faust
const FactorsFormat factors_format, const FactorsFormat factors_format,
const int prod_mod, const int prod_mod,
Real<FPP> &c, Real<FPP> &c,
const Real<FPP>& lambda); const Real<FPP>& lambda,
bool use_grad1=false);
} }
......
...@@ -18,11 +18,21 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -18,11 +18,21 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
{ {
std::chrono::duration<double> norm2_duration = std::chrono::duration<double>::zero(); std::chrono::duration<double> norm2_duration = std::chrono::duration<double>::zero();
std::chrono::duration<double> fgrad_duration = std::chrono::duration<double>::zero(); std::chrono::duration<double> fgrad_duration = std::chrono::duration<double>::zero();
double norm1, norm2;
/* variable environment parameters (which are interesting enough for debugging/profiling but not yet for the wrapper user API */
char* str_env_prod_mod = getenv("PROD_MOD"); char* str_env_prod_mod = getenv("PROD_MOD");
int prod_mod = DYNPROG; // GREEDY_ALL_BEST_GENMAT; DYNPROG is a bit better to factorize the MEG matrix and not slower to factorize a Hadamard matrix int prod_mod = DYNPROG; // GREEDY_ALL_BEST_GENMAT; DYNPROG is a bit better to factorize the MEG matrix and not slower to factorize a Hadamard matrix
if(str_env_prod_mod) if(str_env_prod_mod)
prod_mod = std::atoi(str_env_prod_mod); prod_mod = std::atoi(str_env_prod_mod);
double norm1, norm2; auto str_env_no_lambda_error = getenv("NO_LAMBDA_ERROR");
bool no_lambda_error = false;
if(str_env_no_lambda_error)
no_lambda_error = (bool) std::atoi(str_env_no_lambda_error);
bool use_grad1 = false;
auto str_env_use_grad1 = getenv("USE_GRAD1");
if(str_env_use_grad1)
use_grad1 = (bool) std::atoi(str_env_use_grad1);
/******************************************************/
// std::cout << "palm4msa2 "<< std::endl; // std::cout << "palm4msa2 "<< std::endl;
if(constraints.size() == 0) if(constraints.size() == 0)
throw out_of_range("No constraint passed to palm4msa."); throw out_of_range("No constraint passed to palm4msa.");
...@@ -158,12 +168,12 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -158,12 +168,12 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
norm2_duration, norm2_duration,
fgrad_duration, fgrad_duration,
constant_step_size, step_size, constant_step_size, step_size,
sc, error, factors_format, prod_mod, c, lambda); sc, error, factors_format, prod_mod, c, lambda, use_grad1);
next_fid(); // f_id updated to iteration factor index (pL or pR too) next_fid(); // f_id updated to iteration factor index (pL or pR too)
} }
//update lambda //update lambda
update_lambda(S, pL, pR, A_H, lambda); update_lambda(S, pL, pR, A_H, lambda, no_lambda_error);
if(is_verbose) if(is_verbose)
{ {
set_calc_err_ite_period(); //macro setting the variable ite_period set_calc_err_ite_period(); //macro setting the variable ite_period
...@@ -332,7 +342,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI ...@@ -332,7 +342,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
} }
template<typename FPP, FDevice DEVICE> template<typename FPP, FDevice DEVICE>
void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda) void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda, bool no_lambda_error/*= false*/)
{ {
Faust::MatDense<FPP,DEVICE> A_H_S; Faust::MatDense<FPP,DEVICE> A_H_S;
MatDense<FPP, DEVICE> S_mat; MatDense<FPP, DEVICE> S_mat;
...@@ -365,7 +375,14 @@ void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<Tra ...@@ -365,7 +375,14 @@ void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<Tra
tr = A_H_S.trace(); tr = A_H_S.trace();
nS = S_mat.norm(); nS = S_mat.norm();
if(std::numeric_limits<Real<FPP>>::epsilon() >= nS) if(std::numeric_limits<Real<FPP>>::epsilon() >= nS)
throw std::runtime_error("Error in update_lambda: S Frobenius norm is zero, can't compute lambda."); if(no_lambda_error)
{
// don't change lambda
std::cout << "WARNING: lambda didn't change because S Fro. norm is zero." << std::endl;
return;
}
else
throw std::runtime_error("Error in update_lambda: S Frobenius norm is zero, can't compute lambda.");
lambda = std::real(tr)/(nS*nS); lambda = std::real(tr)/(nS*nS);
} }
...@@ -391,7 +408,8 @@ void Faust::update_fact( ...@@ -391,7 +408,8 @@ void Faust::update_fact(
const FactorsFormat factors_format, const FactorsFormat factors_format,
const int prod_mod, const int prod_mod,
Real<FPP> &c, Real<FPP> &c,
const Real<FPP>& lambda) const Real<FPP>& lambda,
bool use_grad1/*= false*/)
{ {
int norm2_flag; int norm2_flag;
std::chrono::time_point<std::chrono::high_resolution_clock> spectral_stop, spectral_start; std::chrono::time_point<std::chrono::high_resolution_clock> spectral_stop, spectral_start;
...@@ -434,7 +452,7 @@ void Faust::update_fact( ...@@ -434,7 +452,7 @@ void Faust::update_fact(
if(is_verbose) if(is_verbose)
fgrad_start = std::chrono::high_resolution_clock::now(); fgrad_start = std::chrono::high_resolution_clock::now();
if(typeid(D) != typeid(Faust::MatDense<FPP,Cpu>)) if(typeid(D) != typeid(Faust::MatDense<FPP,Cpu>) || use_grad1)
// compute_n_apply_grad2 is not yet supported with GPU2 // compute_n_apply_grad2 is not yet supported with GPU2
compute_n_apply_grad1(f_id, A, S, pL, pR, packing_RL, lambda, c, D, sc, error, prod_mod); compute_n_apply_grad1(f_id, A, S, pL, pR, packing_RL, lambda, c, D, sc, error, prod_mod);
else else
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment