Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 4b6f8c08 authored by hhakim's avatar hhakim
Browse files

Add hidden options (through environment variables) in PALM4MSA 2020 (no_lambda_error, use_grad1).

parent 92349594
Branches
Tags
No related merge requests found
......@@ -120,7 +120,7 @@ namespace Faust
* \param lambda: the output of the lambda computed by the function.
*/
template<typename FPP, FDevice DEVICE>
void update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda);
void update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda, bool no_lambda_error=false);
template<typename FPP, FDevice DEVICE>
void update_fact(
......@@ -144,7 +144,9 @@ namespace Faust
const FactorsFormat factors_format,
const int prod_mod,
Real<FPP> &c,
const Real<FPP>& lambda);
const Real<FPP>& lambda,
bool use_grad1=false);
}
......
......@@ -18,11 +18,21 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
{
std::chrono::duration<double> norm2_duration = std::chrono::duration<double>::zero();
std::chrono::duration<double> fgrad_duration = std::chrono::duration<double>::zero();
double norm1, norm2;
/* variable environment parameters (which are interesting enough for debugging/profiling but not yet for the wrapper user API */
char* str_env_prod_mod = getenv("PROD_MOD");
int prod_mod = DYNPROG; // GREEDY_ALL_BEST_GENMAT; DYNPROG is a bit better to factorize the MEG matrix and not slower to factorize a Hadamard matrix
if(str_env_prod_mod)
prod_mod = std::atoi(str_env_prod_mod);
double norm1, norm2;
auto str_env_no_lambda_error = getenv("NO_LAMBDA_ERROR");
bool no_lambda_error = false;
if(str_env_no_lambda_error)
no_lambda_error = (bool) std::atoi(str_env_no_lambda_error);
bool use_grad1 = false;
auto str_env_use_grad1 = getenv("USE_GRAD1");
if(str_env_use_grad1)
use_grad1 = (bool) std::atoi(str_env_use_grad1);
/******************************************************/
// std::cout << "palm4msa2 "<< std::endl;
if(constraints.size() == 0)
throw out_of_range("No constraint passed to palm4msa.");
......@@ -158,12 +168,12 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
norm2_duration,
fgrad_duration,
constant_step_size, step_size,
sc, error, factors_format, prod_mod, c, lambda);
sc, error, factors_format, prod_mod, c, lambda, use_grad1);
next_fid(); // f_id updated to iteration factor index (pL or pR too)
}
//update lambda
update_lambda(S, pL, pR, A_H, lambda);
update_lambda(S, pL, pR, A_H, lambda, no_lambda_error);
if(is_verbose)
{
set_calc_err_ite_period(); //macro setting the variable ite_period
......@@ -332,7 +342,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
}
template<typename FPP, FDevice DEVICE>
void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda)
void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const MatDense<FPP, DEVICE> &A_H, Real<FPP>& lambda, bool no_lambda_error/*= false*/)
{
Faust::MatDense<FPP,DEVICE> A_H_S;
MatDense<FPP, DEVICE> S_mat;
......@@ -365,7 +375,14 @@ void Faust::update_lambda(Faust::TransformHelper<FPP,DEVICE>& S, std::vector<Tra
tr = A_H_S.trace();
nS = S_mat.norm();
if(std::numeric_limits<Real<FPP>>::epsilon() >= nS)
throw std::runtime_error("Error in update_lambda: S Frobenius norm is zero, can't compute lambda.");
if(no_lambda_error)
{
// don't change lambda
std::cout << "WARNING: lambda didn't change because S Fro. norm is zero." << std::endl;
return;
}
else
throw std::runtime_error("Error in update_lambda: S Frobenius norm is zero, can't compute lambda.");
lambda = std::real(tr)/(nS*nS);
}
......@@ -391,7 +408,8 @@ void Faust::update_fact(
const FactorsFormat factors_format,
const int prod_mod,
Real<FPP> &c,
const Real<FPP>& lambda)
const Real<FPP>& lambda,
bool use_grad1/*= false*/)
{
int norm2_flag;
std::chrono::time_point<std::chrono::high_resolution_clock> spectral_stop, spectral_start;
......@@ -434,7 +452,7 @@ void Faust::update_fact(
if(is_verbose)
fgrad_start = std::chrono::high_resolution_clock::now();
if(typeid(D) != typeid(Faust::MatDense<FPP,Cpu>))
if(typeid(D) != typeid(Faust::MatDense<FPP,Cpu>) || use_grad1)
// compute_n_apply_grad2 is not yet supported with GPU2
compute_n_apply_grad1(f_id, A, S, pL, pR, packing_RL, lambda, c, D, sc, error, prod_mod);
else
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment