From 6c160c5bb3600badaa598293b2d57b5101c29e19 Mon Sep 17 00:00:00 2001 From: hhakim <hakim.hadj-djilani@inria.fr> Date: Tue, 29 Jun 2021 11:56:33 +0200 Subject: [PATCH] Fix a segfault in PALM4MSA2020 when packing_RL == false. It was wrongly assumed that the matrix is a MatDense, but it can be a MatSparse, in which case the dynamic_cast to MatDense returns a nullptr whose subsequent use causes the segfault. --- .../factorization/faust_palm4msa2020.h | 5 +++-- .../factorization/faust_palm4msa2020_2.hpp | 21 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/algorithm/factorization/faust_palm4msa2020.h b/src/algorithm/factorization/faust_palm4msa2020.h index e825b2f52..b28a60499 100644 --- a/src/algorithm/factorization/faust_palm4msa2020.h +++ b/src/algorithm/factorization/faust_palm4msa2020.h @@ -103,10 +103,10 @@ namespace Faust // warning: before calling compute_n_apply_grad*() out must be initialized to S[f_id] : the factor to update // TODO: ideally compute_n_apply_grad1 has no reason to be kept, compute_n_apply_grad2 is faster (but just in case I decided to keep it for a moment) template <typename FPP, FDevice DEVICE> - void compute_n_apply_grad1(const int f_id, const MatDense<FPP,DEVICE> &A, TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const Real<FPP>& lambda, const Real<FPP>& c, MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod); + void compute_n_apply_grad1(const int f_id, const MatDense<FPP,DEVICE> &A, TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const bool packing_RL, const Real<FPP>& lambda, const Real<FPP>& c, MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod); template <typename FPP, FDevice DEVICE> - void compute_n_apply_grad2(const int f_id, const MatDense<FPP,DEVICE> 
&A, TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const Real<FPP>& lambda, const Real<FPP> &c, MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod); + void compute_n_apply_grad2(const int f_id, const MatDense<FPP,DEVICE> &A, TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const bool packing_RL, const Real<FPP>& lambda, const Real<FPP> &c, MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod); template<typename FPP, FDevice DEVICE> Real<FPP> calc_rel_err(const TransformHelper<FPP,DEVICE>& S, const MatDense<FPP,DEVICE> &A, const Real<FPP> &lambda=1, const Real<FPP>* A_norm=nullptr); @@ -129,6 +129,7 @@ namespace Faust Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*> &pR, + const bool packing_RL, const bool is_verbose, const Faust::ConstraintGeneric &constraints, const int norm2_max_iter, diff --git a/src/algorithm/factorization/faust_palm4msa2020_2.hpp b/src/algorithm/factorization/faust_palm4msa2020_2.hpp index b6e101707..b4beb305a 100644 --- a/src/algorithm/factorization/faust_palm4msa2020_2.hpp +++ b/src/algorithm/factorization/faust_palm4msa2020_2.hpp @@ -143,14 +143,14 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, cur_fac = S.get_gen_fact_nonconst(f_id); if(mhtp_params.used && i%mhtp_params.palm4msa_period == 0) { - perform_MHTP(mhtp_params, A, A_H, S, f_id, pL, pR, + perform_MHTP(mhtp_params, A, A_H, S, f_id, pL, pR, packing_RL, is_verbose, *constraints[f_id], norm2_max_iter, norm2_threshold, norm2_duration, fgrad_duration, sc, error, factors_format, prod_mod, c, lambda); } else - update_fact(cur_fac, f_id, A, S, pL, pR, + update_fact(cur_fac, f_id, A, S, pL, pR, packing_RL, is_verbose, 
*constraints[f_id], norm2_max_iter, norm2_threshold, norm2_duration, fgrad_duration, @@ -192,7 +192,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, } template <typename FPP, FDevice DEVICE> -void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVICE> &A, Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const Real<FPP>& lambda, const Real<FPP> &c, Faust::MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod) +void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVICE> &A, Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const bool packing_RL, const Real<FPP>& lambda, const Real<FPP> &c, Faust::MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod) { Faust::MatDense<FPP,DEVICE> tmp; Faust::MatDense<FPP,DEVICE> & D = out; @@ -211,7 +211,7 @@ void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVI auto pL_sz = pL[f_id]->size(); if(pR_sz > 0) { - if(pR_sz == 1) // packing_RL == true + if(pR_sz == 1 && packing_RL) // packing_RL == true LorR = dynamic_cast<Faust::MatDense<FPP,DEVICE>*>(pR[f_id]->get_gen_fact_nonconst(0)); //normally pR[f_id] is packed (hence reduced to a single MatDense) else { @@ -229,7 +229,7 @@ void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVI } if(pL_sz > 0) { - if(pL_sz == 1) // packing_RL == true + if(pL_sz == 1 && packing_RL) // packing_RL == true LorR = dynamic_cast<Faust::MatDense<FPP,DEVICE>*>(pL[f_id]->get_gen_fact_nonconst(0)); else { @@ -243,7 +243,7 @@ void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVI } template <typename FPP, FDevice DEVICE> -void Faust::compute_n_apply_grad2(const int f_id, const 
Faust::MatDense<FPP,DEVICE> &A, Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const Real<FPP>& lambda, const Real<FPP> &c, Faust::MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod) +void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVICE> &A, Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const bool packing_RL, const Real<FPP>& lambda, const Real<FPP> &c, Faust::MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod) { Faust::MatDense<FPP,DEVICE> tmp; Faust::MatDense<FPP,DEVICE> grad_over_c; @@ -260,7 +260,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI auto pL_sz = pL[f_id]->size(); if(pR_sz > 0) { - if(pR_sz == 1) // packing_RL == true + if(pR_sz == 1 && packing_RL) _R = dynamic_cast<Faust::MatDense<FPP,DEVICE>*>(pR[f_id]->get_gen_fact_nonconst(0)); else { @@ -271,7 +271,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI } if(pL_sz > 0) { - if(pL_sz == 1) // packing_RL == true + if(pL_sz == 1 && packing_RL) _L = dynamic_cast<Faust::MatDense<FPP,DEVICE>*>(pL[f_id]->get_gen_fact_nonconst(0)); else { @@ -364,6 +364,7 @@ void Faust::update_fact( Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*> &pR, + const bool packing_RL, const bool is_verbose, const Faust::ConstraintGeneric &constraint, const int norm2_max_iter, @@ -422,9 +423,9 @@ void Faust::update_fact( fgrad_start = std::chrono::high_resolution_clock::now(); if(typeid(D) != typeid(Faust::MatDense<FPP,Cpu>)) // compute_n_apply_grad2 is not yet supported with GPU2 - compute_n_apply_grad1(f_id, A, S, pL, pR, lambda, c, D, sc, error, prod_mod); + 
compute_n_apply_grad1(f_id, A, S, pL, pR, packing_RL, lambda, c, D, sc, error, prod_mod); else - compute_n_apply_grad2(f_id, A, S, pL, pR, lambda, c, D, sc, error, prod_mod); + compute_n_apply_grad2(f_id, A, S, pL, pR, packing_RL, lambda, c, D, sc, error, prod_mod); if(is_verbose) { fgrad_stop = std::chrono::high_resolution_clock::now(); -- GitLab