-
hhakim authored
Rename the Device type to FDevice to avoid conflicts with libtorch (both Faust and libtorch define this type name outside of any namespace); also remove 'using namespace ...' from some headers (where it shouldn't be).
hhakim authored: Rename the Device type to FDevice to avoid conflicts with libtorch (both Faust and libtorch define this type name outside of any namespace); also remove 'using namespace ...' from some headers (where it shouldn't be).
faust_hierarchical.hpp 5.92 KiB
/**
 * Computes a hierarchical factorization of A driven by the parameters p.
 *
 * One iteration is run per factor constraint in p.cons[0]. At iteration i:
 *  1. the factor to split, Si, is taken from S (the first factor when
 *     p.isFactSideLeft is true, the i-th factor otherwise);
 *  2. Si is split into two factors by a local palm4msa2 run constrained by
 *     p.cons[0][i] and p.cons[1][i], starting from a zero factor (shaped by
 *     p.cons[0][i]) and an identity factor (shaped by p.cons[1][i]);
 *  3. Si is replaced in S by the two new factors;
 *  4. a global palm4msa2 run refines all the factors of S against A, using
 *     p.cons[0][0..i] followed by p.cons[1][i].
 *
 * @param A the matrix to factorize; it is copied into the returned Faust.
 * @param p the hierarchical parameters (constraints, stopping criteria, update
 *        way, factorization side, palm4msa2 options). p.init_lambda is not
 *        read: every local optimization starts from lambda = 1.
 * @param lambda output: receives the global scale factor as left by the last
 *        global palm4msa2 run (1 if p.cons[0] is empty).
 * @param compute_2norm_on_array forwarded to every palm4msa2 call.
 * @return a newly allocated TransformHelper holding the factors; the caller
 *         owns it.
 * @throws std::runtime_error if p.cons[1] holds fewer constraints than
 *         p.cons[0], or if a factor to split is neither a MatDense nor a
 *         MatSparse.
 */
template<typename FPP, FDevice DEVICE>
Faust::TransformHelper<FPP,DEVICE>* Faust::hierarchical(const Faust::MatDense<FPP,DEVICE>& A,
		Params<FPP,DEVICE, Real<FPP>> & p,
		Real<FPP>& lambda, const bool compute_2norm_on_array)
{
	std::vector<const Faust::ConstraintGeneric*> & fac_constraints = p.cons[0];
	std::vector<const Faust::ConstraintGeneric*> & res_constraints = p.cons[1];
	// res_constraints[i] is read for every i < fac_constraints.size(): check it
	// before anything is allocated rather than reading out of bounds
	if(res_constraints.size() < fac_constraints.size())
		throw std::runtime_error("Faust::hierarchical: p.cons[1] (residual constraints) must hold at least as many constraints as p.cons[0] (factor constraints).");
	//TODO: remove these local variables and use directly p.
	const bool is_update_way_R2L = p.isUpdateWayR2L;
	const bool is_fact_side_left = p.isFactSideLeft;
	const bool use_csr = p.use_csr;
	const bool packing_RL = p.packing_RL;
	const Real<FPP> norm2_threshold = p.norm2_threshold;
	const unsigned int norm2_max_iter = p.norm2_max_iter;
	const double step_size = p.step_size;
	const bool constant_step_size = p.isConstantStepSize;

	auto S = new Faust::TransformHelper<FPP,DEVICE>(); // A is copied
	S->push_back(&A);
	Real<FPP> glo_lambda = 1;
	if(p.isVerbose) p.Display();
	for(std::size_t i = 0; i < fac_constraints.size(); i++)
	{
		std::cout << "Faust::hierarchical: " << i+1 << std::endl;
		// the factor to split: the leftmost one when factorizing on the left
		// side, the i-th one otherwise
		Faust::MatGeneric<FPP,DEVICE>* Si = is_fact_side_left ? S->get_gen_fact_nonconst(0) : S->get_gen_fact_nonconst(i);
		const Faust::ConstraintGeneric* fac_cons = fac_constraints[i];
		const Faust::ConstraintGeneric* res_cons = res_constraints[i];

		// palm4msa2 factorizes a MatDense: use Si directly when it is dense,
		// otherwise work on a temporary dense copy of the sparse factor
		// (on the same DEVICE as Si)
		Faust::MatDense<FPP,DEVICE>* Si_dense = dynamic_cast<Faust::MatDense<FPP,DEVICE>*>(Si);
		Faust::MatDense<FPP,DEVICE>* Si_dense_copy = nullptr; // owned here, freed after the local optimization
		if(Si_dense == nullptr)
		{
			Faust::MatSparse<FPP,DEVICE>* Si_sparse = dynamic_cast<Faust::MatSparse<FPP,DEVICE>*>(Si);
			if(Si_sparse == nullptr)
			{
				delete S;
				throw std::runtime_error("Faust::hierarchical: the factor to split must be a MatDense or a MatSparse.");
			}
			Si_dense_copy = new Faust::MatDense<FPP,DEVICE>(*Si_sparse);
			Si_dense = Si_dense_copy;
		}

		// init factors for the local optimization (factorization of Si):
		// a zero factor shaped by fac_cons and an identity factor shaped by res_cons
		//TODO: refactor into a separate function init_zero_id
		Faust::MatGeneric<FPP,DEVICE>* zero_mat;
		Faust::MatGeneric<FPP,DEVICE>* id_mat;
		if(use_csr)
		{
			auto zero_sp = new Faust::MatSparse<FPP,DEVICE>(fac_cons->get_rows(), fac_cons->get_cols());
			zero_sp->setZeros();
			zero_mat = zero_sp;
			auto id_sp = new Faust::MatSparse<FPP,DEVICE>(res_cons->get_rows(), res_cons->get_cols());
			id_sp->setEyes();
			id_mat = id_sp;
		}
		else
		{
			auto zero_ds = new Faust::MatDense<FPP,DEVICE>(fac_cons->get_rows(), fac_cons->get_cols());
			zero_ds->setZeros();
			zero_mat = zero_ds;
			auto id_ds = new Faust::MatDense<FPP,DEVICE>(res_cons->get_rows(), res_cons->get_cols());
			id_ds->setEyes();
			id_mat = id_ds;
		}
		std::vector<Faust::MatGeneric<FPP,DEVICE>*> Si_vec;
		if(is_update_way_R2L)
			Si_vec = {id_mat, zero_mat};
		else
			Si_vec = {zero_mat, id_mat};
		// NOTE(review): Si_cons is always {fac_cons, res_cons} while the order of
		// Si_vec depends on is_update_way_R2L (and neither depends on
		// is_fact_side_left) -- confirm palm4msa2 pairs factors and constraints
		// as intended
		std::vector<Faust::ConstraintGeneric*> Si_cons = { const_cast<Faust::ConstraintGeneric*>(fac_cons), const_cast<Faust::ConstraintGeneric*>(res_cons) };
		Faust::TransformHelper<FPP,DEVICE> Si_th(Si_vec, 1.0, false, false, true);
		Real<FPP> lambda_ = 1; // each local optimization starts from lambda = 1
		Faust::palm4msa2(*Si_dense, Si_cons, Si_th, lambda_, p.stop_crit_2facts, is_update_way_R2L, use_csr, packing_RL, compute_2norm_on_array,
				norm2_threshold, norm2_max_iter, constant_step_size, step_size);
		// the temporary dense copy of a sparse Si is no longer needed
		// TODO: palm4msa2 should handle this on its own
		delete Si_dense_copy;

		//prepare global optimization
		glo_lambda *= lambda_;
		// replace Si in S by the two factors computed by the local optimization
		if(is_fact_side_left)
		{
			S->pop_front();
			S->push_first(*(Si_th.begin()+1), false, false);
			S->push_first(*(Si_th.begin()), false, false);
		}
		else
		{
			S->pop_back();
			S->push_back(*(Si_th.begin()), false, false);
			S->push_back(*(Si_th.begin()+1), false, false);
		}
		//TODO: verify if the constraints order doesn't depend on
		//is_fact_side_left
		std::vector<Faust::ConstraintGeneric*> glo_cons;
		glo_cons.reserve(i+2);
		for(std::size_t j = 0; j <= i; j++)
			glo_cons.push_back(const_cast<Faust::ConstraintGeneric*>(fac_constraints[j]));
		glo_cons.push_back(const_cast<Faust::ConstraintGeneric*>(res_cons));
		// global optimization
		Faust::palm4msa2(A, glo_cons, *S, glo_lambda, p.stop_crit_global, is_update_way_R2L, use_csr, packing_RL, compute_2norm_on_array,
				norm2_threshold, norm2_max_iter, constant_step_size, step_size);
	}
	lambda = glo_lambda;
	return S;
}
/**
 * Convenience overload of Faust::hierarchical taking the options one by one.
 *
 * It packs the arguments into a Faust::Params object (with
 * fac_constraints.size()+1 factors and no initial factors) and delegates to the
 * Params-based overload.
 *
 * @param A the matrix to factorize.
 * @param sc stopping criteria; sc[0] and sc[1] are forwarded to the Params
 *        constructor (presumably as the two-factor and the global stopping
 *        criteria respectively -- confirm against Params). sc must hold at
 *        least two elements, otherwise std::out_of_range is thrown.
 * @param fac_constraints the factor constraints (one per hierarchical step).
 * @param res_constraints the residual constraints (one per hierarchical step).
 * @param lambda in: passed to Params as its initial lambda; out: the scale
 *        factor returned by the Params-based overload.
 * @return a newly allocated TransformHelper; the caller owns it.
 */
template<typename FPP, FDevice DEVICE>
Faust::TransformHelper<FPP,DEVICE>* Faust::hierarchical(const Faust::MatDense<FPP,DEVICE>& A,
		std::vector<StoppingCriterion<Real<FPP>>>& sc,
		std::vector<const Faust::ConstraintGeneric*> & fac_constraints,
		std::vector<const Faust::ConstraintGeneric*> & res_constraints,
		Real<FPP>& lambda,
		const bool is_update_way_R2L, const bool is_fact_side_left,
		const bool use_csr, const bool packing_RL,
		const bool compute_2norm_on_array,
		const Real<FPP> norm2_threshold,
		const unsigned int norm2_max_iter, const bool is_verbose,
		const bool constant_step_size, const Real<FPP> step_size)
{
	// at() instead of [] so that a too-short sc throws instead of reading out of bounds
	Faust::Params<FPP,DEVICE,Real<FPP>> p(A.getNbRow(), A.getNbCol(), fac_constraints.size()+1,
			{fac_constraints, res_constraints}, {}, sc.at(0), sc.at(1), is_verbose,
			is_update_way_R2L, is_fact_side_left, lambda, constant_step_size, step_size);
	// options not covered by the Params constructor
	p.use_csr = use_csr;
	p.packing_RL = packing_RL;
	p.norm2_threshold = norm2_threshold;
	p.norm2_max_iter = norm2_max_iter;
	return Faust::hierarchical(A, p, lambda, compute_2norm_on_array);
}