Mentions légales du service

Skip to content
Snippets Groups Projects
Commit a9d2e69a authored by hhakim's avatar hhakim
Browse files

In PALM4MSA 2020 impl. if custom factors are defined at initialization,...

In PALM4MSA 2020 impl. if custom factors are defined at initialization, convert them to the configured FactorFormat (AllSparse, AllDense).

- Add member functions convertToSparse(), convertToDense() in TransformHelper classes.

Fix assertion about consistency of factor format in update_fact() (when factors_format == AllSparse or AllDense).
Without the conversion this assertion was failing anyway.
parent 24f47cae
No related branches found
No related tags found
No related merge requests found
...@@ -43,6 +43,14 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -43,6 +43,14 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
A_H.adjoint(); A_H.adjoint();
if(S.size() != nfacts) if(S.size() != nfacts)
fill_of_eyes(S, nfacts, factors_format != AllDense, dims, on_gpu); fill_of_eyes(S, nfacts, factors_format != AllDense, dims, on_gpu);
else if(factors_format == AllSparse)
{
S.convertToSparse();
}
else if(factors_format == AllDense)
{
S.convertToDense();
}
int i = 0, f_id, j; int i = 0, f_id, j;
std::function<void()> init_ite, next_fid; std::function<void()> init_ite, next_fid;
std::function<bool()> updating_facs; std::function<bool()> updating_facs;
...@@ -409,7 +417,7 @@ void Faust::update_fact( ...@@ -409,7 +417,7 @@ void Faust::update_fact(
// D is the prox image (always a MatDense // D is the prox image (always a MatDense
// convert D to the proper format (MatSparse or MatDense) // convert D to the proper format (MatSparse or MatDense)
if(factors_format == AllSparse && dcur_fac != nullptr || factors_format != AllSparse && scur_fac != nullptr) if(factors_format == AllSparse && dcur_fac != nullptr || factors_format == AllDense && scur_fac != nullptr)
throw std::runtime_error("Current factor is inconsistent with the configured factors_format."); throw std::runtime_error("Current factor is inconsistent with the configured factors_format.");
if(factors_format == AllSparse) if(factors_format == AllSparse)
......
...@@ -208,6 +208,8 @@ namespace Faust ...@@ -208,6 +208,8 @@ namespace Faust
const MatGeneric<FPP,Cpu>* get_gen_fact(const faust_unsigned_int id) const; const MatGeneric<FPP,Cpu>* get_gen_fact(const faust_unsigned_int id) const;
void update(const MatGeneric<FPP,Cpu>& M, const faust_unsigned_int fact_id); void update(const MatGeneric<FPP,Cpu>& M, const faust_unsigned_int fact_id);
void replace(const MatGeneric<FPP, Cpu>* M, const faust_unsigned_int fact_id); void replace(const MatGeneric<FPP, Cpu>* M, const faust_unsigned_int fact_id);
void convertToSparse();
void convertToDense();
}; };
......
...@@ -718,6 +718,36 @@ template<typename FPP> ...@@ -718,6 +718,36 @@ template<typename FPP>
update_total_nnz(); update_total_nnz();
} }
template<typename FPP>
void TransformHelper<FPP, Cpu>::convertToSparse()
{
const MatDense<FPP,Cpu> * mat_dense;
const MatSparse<FPP,Cpu> * mat_sparse;
for(int i=0;i<this->size();i++)
{
if(mat_dense = dynamic_cast<const MatDense<FPP,Cpu>*>(this->get_gen_fact(i)))
{
mat_sparse = new MatSparse<FPP,Cpu>(*mat_dense);
this->replace(mat_sparse, i);
}
}
}
template<typename FPP>
void TransformHelper<FPP, Cpu>::convertToDense()
{
const MatDense<FPP,Cpu> * mat_dense;
const MatSparse<FPP,Cpu> * mat_sparse;
for(int i=0;i<this->size();i++)
{
if(mat_sparse = dynamic_cast<const MatSparse<FPP,Cpu>*>(this->get_gen_fact(i)))
{
mat_dense = new MatDense<FPP,Cpu>(*mat_sparse);
this->replace(mat_dense, i);
}
}
}
template<typename FPP> template<typename FPP>
void TransformHelper<FPP, Cpu>::replace(const MatGeneric<FPP, Cpu>* M, const faust_unsigned_int fact_id) void TransformHelper<FPP, Cpu>::replace(const MatGeneric<FPP, Cpu>* M, const faust_unsigned_int fact_id)
{ {
......
...@@ -19,7 +19,7 @@ namespace Faust ...@@ -19,7 +19,7 @@ namespace Faust
TransformHelper(TransformHelper<FPP,GPU2>* th, faust_unsigned_int* row_ids, faust_unsigned_int num_rows, faust_unsigned_int* col_ids, faust_unsigned_int num_cols); TransformHelper(TransformHelper<FPP,GPU2>* th, faust_unsigned_int* row_ids, faust_unsigned_int num_rows, faust_unsigned_int* col_ids, faust_unsigned_int num_cols);
TransformHelper(const std::vector<MatGeneric<FPP,GPU2> *>& facts, const FPP lambda_ = (FPP)1.0, const bool optimizedCopy=false, const bool cloning_fact = true, const bool internal_call=false); TransformHelper(const std::vector<MatGeneric<FPP,GPU2> *>& facts, const FPP lambda_ = (FPP)1.0, const bool optimizedCopy=false, const bool cloning_fact = true, const bool internal_call=false);
TransformHelper(const TransformHelper<FPP,Cpu>& cpu_t, int32_t dev_id=-1, void* stream=nullptr); TransformHelper(const TransformHelper<FPP,Cpu>& cpu_t, int32_t dev_id=-1, void* stream=nullptr);
TransformHelper(const TransformHelper<FPP,GPU2>& th, bool transpose, bool conjugate); TransformHelper(const TransformHelper<FPP,GPU2>& th, bool transpose, bool conjugate);
#ifndef IGNORE_TRANSFORM_HELPER_VARIADIC_TPL #ifndef IGNORE_TRANSFORM_HELPER_VARIADIC_TPL
template<typename ...GList> TransformHelper(GList& ... t); template<typename ...GList> TransformHelper(GList& ... t);
#endif #endif
...@@ -59,6 +59,9 @@ namespace Faust ...@@ -59,6 +59,9 @@ namespace Faust
void pack_factors(faust_unsigned_int start_id, faust_unsigned_int end_id, const int mul_order_opt_mode=DEFAULT); void pack_factors(faust_unsigned_int start_id, faust_unsigned_int end_id, const int mul_order_opt_mode=DEFAULT);
void update(const MatGeneric<FPP, GPU2>& M, const faust_unsigned_int id); void update(const MatGeneric<FPP, GPU2>& M, const faust_unsigned_int id);
void replace(const MatGeneric<FPP, GPU2>* M, const faust_unsigned_int id); void replace(const MatGeneric<FPP, GPU2>* M, const faust_unsigned_int id);
void convertToSparse();
void convertToDense();
void operator=(TransformHelper<FPP,GPU2>& th); void operator=(TransformHelper<FPP,GPU2>& th);
typename Transform<FPP,GPU2>::iterator begin() const; typename Transform<FPP,GPU2>::iterator begin() const;
typename Transform<FPP,GPU2>::iterator end() const; typename Transform<FPP,GPU2>::iterator end() const;
...@@ -79,7 +82,7 @@ namespace Faust ...@@ -79,7 +82,7 @@ namespace Faust
void set_FM_mul_mode(const int mul_order_opt_mode, const bool silent=false) const; void set_FM_mul_mode(const int mul_order_opt_mode, const bool silent=false) const;
void set_Fv_mul_mode(const int Fv_mul_mode) const; void set_Fv_mul_mode(const int Fv_mul_mode) const;
faust_unsigned_int get_total_nnz() const; faust_unsigned_int get_total_nnz() const;
// faust_unsigned_int get_fact_nnz(const faust_unsigned_int id) const; // faust_unsigned_int get_fact_nnz(const faust_unsigned_int id) const;
TransformHelper<FPP,GPU2>* normalize(const int meth /* 1 for 1-norm, 2 for 2-norm (2-norm), -1 for inf-norm */) const; TransformHelper<FPP,GPU2>* normalize(const int meth /* 1 for 1-norm, 2 for 2-norm (2-norm), -1 for inf-norm */) const;
TransformHelper<FPP,GPU2>* transpose(); TransformHelper<FPP,GPU2>* transpose();
TransformHelper<FPP,GPU2>* conjugate(); TransformHelper<FPP,GPU2>* conjugate();
......
...@@ -228,6 +228,36 @@ namespace Faust ...@@ -228,6 +228,36 @@ namespace Faust
return this->transform->replace(M, id); return this->transform->replace(M, id);
} }
template<typename FPP>
void TransformHelper<FPP, GPU2>::convertToSparse()
{
const MatDense<FPP,GPU2> * mat_dense;
const MatSparse<FPP,GPU2> * mat_sparse;
for(int i=0;i<this->size();i++)
{
if(mat_dense = dynamic_cast<const MatDense<FPP,GPU2>*>(this->get_gen_fact(i)))
{
mat_sparse = new MatSparse<FPP,GPU2>(*mat_dense);
this->replace(mat_sparse, i);
}
}
}
template<typename FPP>
void TransformHelper<FPP, GPU2>::convertToDense()
{
const MatDense<FPP,GPU2> * mat_dense;
const MatSparse<FPP,GPU2> * mat_sparse;
for(int i=0;i<this->size();i++)
{
if(mat_sparse = dynamic_cast<const MatSparse<FPP,GPU2>*>(this->get_gen_fact(i)))
{
mat_dense = new MatDense<FPP,GPU2>(*mat_sparse);
this->replace(mat_dense, i);
}
}
}
template<typename FPP> template<typename FPP>
TransformHelper<FPP,GPU2>* TransformHelper<FPP,GPU2>::multiply(const TransformHelper<FPP,GPU2>* right) TransformHelper<FPP,GPU2>* TransformHelper<FPP,GPU2>::multiply(const TransformHelper<FPP,GPU2>* right)
{ {
......
...@@ -96,6 +96,9 @@ namespace Faust ...@@ -96,6 +96,9 @@ namespace Faust
TransformHelper<FPP, DEV>* fancy_index(faust_unsigned_int* row_ids, faust_unsigned_int num_rows, faust_unsigned_int* col_ids, faust_unsigned_int num_cols); TransformHelper<FPP, DEV>* fancy_index(faust_unsigned_int* row_ids, faust_unsigned_int num_rows, faust_unsigned_int* col_ids, faust_unsigned_int num_cols);
virtual TransformHelper<FPP,DEV>* optimize_storage(const bool time=true); virtual TransformHelper<FPP,DEV>* optimize_storage(const bool time=true);
virtual TransformHelper<FPP,DEV>* clone(); virtual TransformHelper<FPP,DEV>* clone();
virtual void convertToSparse()=0;
virtual void convertToDense()=0;
protected: protected:
void init_fancy_idx_transform(TransformHelper<FPP,DEV>* th, faust_unsigned_int* row_ids, faust_unsigned_int num_rows, faust_unsigned_int* col_ids, faust_unsigned_int num_cols); void init_fancy_idx_transform(TransformHelper<FPP,DEV>* th, faust_unsigned_int* row_ids, faust_unsigned_int num_rows, faust_unsigned_int* col_ids, faust_unsigned_int num_cols);
void init_sliced_transform(TransformHelper<FPP,DEV>* th, Slice s[2]); void init_sliced_transform(TransformHelper<FPP,DEV>* th, Slice s[2]);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment