Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 695b587e authored by hhakim
Browse files

Make palm4msa2020 (impl1) capable of performing Faust/matrix multiplications on GPU.

parent 156fed08
Branches
Tags
No related merge requests found
...@@ -19,6 +19,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -19,6 +19,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
Faust::MatGeneric<FPP,DEVICE>* cur_fac; Faust::MatGeneric<FPP,DEVICE>* cur_fac;
Faust::MatSparse<FPP,DEVICE>* scur_fac; Faust::MatSparse<FPP,DEVICE>* scur_fac;
Faust::MatDense<FPP,DEVICE>* dcur_fac; Faust::MatDense<FPP,DEVICE>* dcur_fac;
Faust::MatSparse<FPP,DEVICE> spD;
Real<FPP> error = -1; // negative error is ignored Real<FPP> error = -1; // negative error is ignored
const unsigned int nfacts = constraints.size(); const unsigned int nfacts = constraints.size();
std::vector<std::pair<faust_unsigned_int,faust_unsigned_int>> dims; std::vector<std::pair<faust_unsigned_int,faust_unsigned_int>> dims;
...@@ -27,32 +28,10 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -27,32 +28,10 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
dims.push_back(make_pair(c->get_rows(), c->get_cols())); dims.push_back(make_pair(c->get_rows(), c->get_cols()));
Faust::MatDense<FPP,DEVICE> A_H = A; Faust::MatDense<FPP,DEVICE> A_H = A;
A_H.adjoint(); A_H.adjoint();
// if(S.size() != nfacts)
// {
//// S = Faust::TransformHelper<FPP,DEVICE>();
// //TODO: refactor the id factor gen. into TransformHelper
// for(auto fdims : dims)
// {
// // init all facts as identity matrices
// // with proper dimensions
// Faust::MatGeneric<FPP,DEVICE>* fact;
// if(use_csr)
// {
// auto sfact = new Faust::MatSparse<FPP,DEVICE>(fdims.first, fdims.second);
// sfact->setEyes();
// fact = sfact;
// }
// else
// {
// auto dfact = new Faust::MatDense<FPP,DEVICE>(fdims.first, fdims.second);
// dfact->setEyes();
// fact = dfact;
// }
// S.push_back(fact); //TODO: copying=false
// }
// }
if(S.size() != nfacts) if(S.size() != nfacts)
fill_of_eyes(S, nfacts, use_csr, dims); fill_of_eyes(S, nfacts, use_csr, dims, on_gpu);
else
if(on_gpu) S.enable_gpu_meth_for_mul();
int i = 0, f_id; int i = 0, f_id;
std::function<void()> init_fid, next_fid; std::function<void()> init_fid, next_fid;
std::function<bool()> updating_facs; std::function<bool()> updating_facs;
...@@ -73,7 +52,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -73,7 +52,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
Faust::TransformHelper<FPP, Cpu> LSR; Faust::TransformHelper<FPP, Cpu> LSR;
Real<FPP> c = 1/step_size; Real<FPP> c = 1/step_size;
// lambda exp to update fact when its id is 0 // lambda exp to update fact when its id is 0
auto update_1stfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter](Faust::MatGeneric<FPP, DEVICE> *cur_fac) auto update_1stfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter, &on_gpu](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
{ {
auto R = S.right(f_id+1); auto R = S.right(f_id+1);
if(! constant_step_size) if(! constant_step_size)
...@@ -104,6 +83,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -104,6 +83,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// _LSR.multiply(lambda); // _LSR.multiply(lambda);
// tmp = _LSR.get_product(); // tmp = _LSR.get_product();
_LSR.get_product(tmp); _LSR.get_product(tmp);
if(on_gpu) assert(10 == _LSR.get_mul_order_opt_mode());
tmp *= lambda; tmp *= lambda;
tmp -= A; tmp -= A;
if(sc.isCriterionErr()) if(sc.isCriterionErr())
...@@ -111,11 +91,12 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -111,11 +91,12 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
//TODO: do something to lighten the double transpose conjugate //TODO: do something to lighten the double transpose conjugate
tmp.adjoint(); tmp.adjoint();
tmp = R->multiply(tmp, /* H */ false, false); tmp = R->multiply(tmp, /* H */ false, false);
if(on_gpu) assert(10 == R->get_mul_order_opt_mode());
tmp.adjoint(); tmp.adjoint();
tmp *= lambda/c; tmp *= lambda/c;
D -= tmp; D -= tmp;
}; };
auto update_lastfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter](Faust::MatGeneric<FPP, DEVICE> *cur_fac) auto update_lastfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter, &on_gpu](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
{ {
auto L = S.left(f_id-1); auto L = S.left(f_id-1);
// TODO: factorize with other lambda exp // TODO: factorize with other lambda exp
...@@ -146,15 +127,17 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -146,15 +127,17 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// LSR = _LSR; // LSR = _LSR;
// _LSR.multiply(lambda); // _LSR.multiply(lambda);
tmp = _LSR.get_product(); tmp = _LSR.get_product();
if(on_gpu) assert(10 == _LSR.get_mul_order_opt_mode());
tmp *= lambda; tmp *= lambda;
tmp -= A; tmp -= A;
if(sc.isCriterionErr()) if(sc.isCriterionErr())
error = tmp.norm(); error = tmp.norm();
tmp = L->multiply(tmp, /* NO H */ true, true); tmp = L->multiply(tmp, /* NO H */ true, true);
if(on_gpu) assert(10 == L->get_mul_order_opt_mode());
tmp *= lambda/c; tmp *= lambda/c;
D -= tmp; D -= tmp;
}; };
auto update_interfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter](Faust::MatGeneric<FPP, DEVICE> *cur_fac) auto update_interfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter, &on_gpu](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
{ {
auto R = S.right(f_id+1); auto R = S.right(f_id+1);
auto L = S.left(f_id-1); auto L = S.left(f_id-1);
...@@ -180,6 +163,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -180,6 +163,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// LSR = _LSR; // LSR = _LSR;
// _LSR.multiply(lambda); // _LSR.multiply(lambda);
tmp = _LSR.get_product(); tmp = _LSR.get_product();
if(on_gpu) assert(10 == _LSR.get_mul_order_opt_mode());
tmp *= lambda; tmp *= lambda;
tmp -= A; tmp -= A;
if(sc.isCriterionErr()) if(sc.isCriterionErr())
...@@ -187,8 +171,10 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -187,8 +171,10 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
//TODO: do something to lighten the double transpose conjugate //TODO: do something to lighten the double transpose conjugate
tmp.adjoint(); tmp.adjoint();
tmp = R->multiply(tmp, /* NO H */ false, false); tmp = R->multiply(tmp, /* NO H */ false, false);
if(on_gpu) assert(10 == R->get_mul_order_opt_mode());
tmp.adjoint(); tmp.adjoint();
tmp = L->multiply(tmp, true, true); tmp = L->multiply(tmp, true, true);
if(on_gpu) assert(10 == L->get_mul_order_opt_mode());
tmp *= lambda/c; tmp *= lambda/c;
D -= tmp; D -= tmp;
}; };
...@@ -201,7 +187,9 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -201,7 +187,9 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// std::cout << "f_id: " << f_id << std::endl; // std::cout << "f_id: " << f_id << std::endl;
cur_fac = S.get_gen_fact_nonconst(f_id); cur_fac = S.get_gen_fact_nonconst(f_id);
if(f_id == 0) if(f_id == 0)
{
update_1stfac(cur_fac); update_1stfac(cur_fac);
}
else if(f_id == nfacts-1) else if(f_id == nfacts-1)
update_lastfac(cur_fac); update_lastfac(cur_fac);
else else
...@@ -209,20 +197,16 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -209,20 +197,16 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// really update now // really update now
constraints[f_id]->project<FPP,DEVICE,Real<FPP>>(D); constraints[f_id]->project<FPP,DEVICE,Real<FPP>>(D);
if(! use_csr && scur_fac == nullptr)
{ if(use_csr && dcur_fac != nullptr || !use_csr && scur_fac != nullptr)
// not CSR and cur_fac DENSE throw std::runtime_error("Current factor is inconsistent with use_csr.");
*dcur_fac = D; if(use_csr)
}
else if(use_csr && dcur_fac == nullptr)
{ {
// CSR and cur_fac SPARSE spD = D;
*scur_fac = D; S.update(spD, f_id); // update is at higher level than a simple assignment
} }
else else
throw std::runtime_error("Current factor is inconsistent with use_csr."); S.update(D, f_id);
cur_fac->set_id(false);
S.update_total_nnz();
next_fid(); //f_id updated to iteration factor index next_fid(); //f_id updated to iteration factor index
} }
//update lambda //update lambda
...@@ -267,5 +251,5 @@ void Faust::fill_of_eyes( ...@@ -267,5 +251,5 @@ void Faust::fill_of_eyes(
} }
S.push_back(fact); //TODO: copying=false S.push_back(fact); //TODO: copying=false
} }
if(on_gpu) S.set_mul_order_opt_mode(10); if(on_gpu) S.enable_gpu_meth_for_mul();
} }
...@@ -34,6 +34,8 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A, ...@@ -34,6 +34,8 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
A_H.adjoint(); A_H.adjoint();
if(S.size() != nfacts) if(S.size() != nfacts)
fill_of_eyes(S, nfacts, use_csr, dims, on_gpu); fill_of_eyes(S, nfacts, use_csr, dims, on_gpu);
else if(on_gpu)
S.enable_gpu_meth_for_mul();
int i = 0, f_id; int i = 0, f_id;
std::function<void()> init_ite, next_fid; std::function<void()> init_ite, next_fid;
std::function<bool()> updating_facs; std::function<bool()> updating_facs;
......
...@@ -102,7 +102,7 @@ namespace Faust ...@@ -102,7 +102,7 @@ namespace Faust
// release all gpu mats // release all gpu mats
for(auto m: cpu_mat_ptrs) for(auto m: cpu_mat_ptrs)
{ {
if(ref_man.contains(m)) //TODO: remove when the bug of nonsense factor is corrected // if(ref_man.contains(m)) // useless assuming we release only the acquired factors
ref_man.release(m); ref_man.release(m);
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment