Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 695b587e authored by hhakim's avatar hhakim
Browse files

Make palm4msa2020 (impl1) capable of performing Faust/matrix multiplications on GPU.

parent 156fed08
No related branches found
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
Faust::MatGeneric<FPP,DEVICE>* cur_fac;
Faust::MatSparse<FPP,DEVICE>* scur_fac;
Faust::MatDense<FPP,DEVICE>* dcur_fac;
Faust::MatSparse<FPP,DEVICE> spD;
Real<FPP> error = -1; // negative error is ignored
const unsigned int nfacts = constraints.size();
std::vector<std::pair<faust_unsigned_int,faust_unsigned_int>> dims;
......@@ -27,32 +28,10 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
dims.push_back(make_pair(c->get_rows(), c->get_cols()));
Faust::MatDense<FPP,DEVICE> A_H = A;
A_H.adjoint();
// if(S.size() != nfacts)
// {
//// S = Faust::TransformHelper<FPP,DEVICE>();
// //TODO: refactor the id factor gen. into TransformHelper
// for(auto fdims : dims)
// {
// // init all facts as identity matrices
// // with proper dimensions
// Faust::MatGeneric<FPP,DEVICE>* fact;
// if(use_csr)
// {
// auto sfact = new Faust::MatSparse<FPP,DEVICE>(fdims.first, fdims.second);
// sfact->setEyes();
// fact = sfact;
// }
// else
// {
// auto dfact = new Faust::MatDense<FPP,DEVICE>(fdims.first, fdims.second);
// dfact->setEyes();
// fact = dfact;
// }
// S.push_back(fact); //TODO: copying=false
// }
// }
if(S.size() != nfacts)
fill_of_eyes(S, nfacts, use_csr, dims);
fill_of_eyes(S, nfacts, use_csr, dims, on_gpu);
else
if(on_gpu) S.enable_gpu_meth_for_mul();
int i = 0, f_id;
std::function<void()> init_fid, next_fid;
std::function<bool()> updating_facs;
......@@ -73,7 +52,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
Faust::TransformHelper<FPP, Cpu> LSR;
Real<FPP> c = 1/step_size;
// lambda exp to update fact when its id is 0
auto update_1stfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
auto update_1stfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter, &on_gpu](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
{
auto R = S.right(f_id+1);
if(! constant_step_size)
......@@ -104,6 +83,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// _LSR.multiply(lambda);
// tmp = _LSR.get_product();
_LSR.get_product(tmp);
if(on_gpu) assert(10 == _LSR.get_mul_order_opt_mode());
tmp *= lambda;
tmp -= A;
if(sc.isCriterionErr())
......@@ -111,11 +91,12 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
//TODO: do something to lighten the double transpose conjugate
tmp.adjoint();
tmp = R->multiply(tmp, /* H */ false, false);
if(on_gpu) assert(10 == R->get_mul_order_opt_mode());
tmp.adjoint();
tmp *= lambda/c;
D -= tmp;
};
auto update_lastfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
auto update_lastfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter, &on_gpu](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
{
auto L = S.left(f_id-1);
// TODO: factorize with other lambda exp
......@@ -146,15 +127,17 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// LSR = _LSR;
// _LSR.multiply(lambda);
tmp = _LSR.get_product();
if(on_gpu) assert(10 == _LSR.get_mul_order_opt_mode());
tmp *= lambda;
tmp -= A;
if(sc.isCriterionErr())
error = tmp.norm();
tmp = L->multiply(tmp, /* NO H */ true, true);
if(on_gpu) assert(10 == L->get_mul_order_opt_mode());
tmp *= lambda/c;
D -= tmp;
};
auto update_interfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
auto update_interfac = [&sc, &error, &constant_step_size, &c, &A, &D, &tmp, &LSR, &scur_fac, &dcur_fac, &f_id, &S, &lipschitz_multiplicator, &lambda, &norm2_threshold, &norm2_flag, &norm2_max_iter, &on_gpu](Faust::MatGeneric<FPP, DEVICE> *cur_fac)
{
auto R = S.right(f_id+1);
auto L = S.left(f_id-1);
......@@ -180,6 +163,7 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// LSR = _LSR;
// _LSR.multiply(lambda);
tmp = _LSR.get_product();
if(on_gpu) assert(10 == _LSR.get_mul_order_opt_mode());
tmp *= lambda;
tmp -= A;
if(sc.isCriterionErr())
......@@ -187,8 +171,10 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
//TODO: do something to lighten the double transpose conjugate
tmp.adjoint();
tmp = R->multiply(tmp, /* NO H */ false, false);
if(on_gpu) assert(10 == R->get_mul_order_opt_mode());
tmp.adjoint();
tmp = L->multiply(tmp, true, true);
if(on_gpu) assert(10 == L->get_mul_order_opt_mode());
tmp *= lambda/c;
D -= tmp;
};
......@@ -201,7 +187,9 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// std::cout << "f_id: " << f_id << std::endl;
cur_fac = S.get_gen_fact_nonconst(f_id);
if(f_id == 0)
{
update_1stfac(cur_fac);
}
else if(f_id == nfacts-1)
update_lastfac(cur_fac);
else
......@@ -209,20 +197,16 @@ void Faust::palm4msa(const Faust::MatDense<FPP,DEVICE>& A,
// really update now
constraints[f_id]->project<FPP,DEVICE,Real<FPP>>(D);
if(! use_csr && scur_fac == nullptr)
{
// not CSR and cur_fac DENSE
*dcur_fac = D;
}
else if(use_csr && dcur_fac == nullptr)
if(use_csr && dcur_fac != nullptr || !use_csr && scur_fac != nullptr)
throw std::runtime_error("Current factor is inconsistent with use_csr.");
if(use_csr)
{
// CSR and cur_fac SPARSE
*scur_fac = D;
spD = D;
S.update(spD, f_id); // update is at higher level than a simple assignment
}
else
throw std::runtime_error("Current factor is inconsistent with use_csr.");
cur_fac->set_id(false);
S.update_total_nnz();
S.update(D, f_id);
next_fid(); //f_id updated to iteration factor index
}
//update lambda
......@@ -267,5 +251,5 @@ void Faust::fill_of_eyes(
}
S.push_back(fact); //TODO: copying=false
}
if(on_gpu) S.set_mul_order_opt_mode(10);
if(on_gpu) S.enable_gpu_meth_for_mul();
}
......@@ -34,6 +34,8 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
A_H.adjoint();
if(S.size() != nfacts)
fill_of_eyes(S, nfacts, use_csr, dims, on_gpu);
else if(on_gpu)
S.enable_gpu_meth_for_mul();
int i = 0, f_id;
std::function<void()> init_ite, next_fid;
std::function<bool()> updating_facs;
......
......@@ -102,7 +102,7 @@ namespace Faust
// release all gpu mats
for(auto m: cpu_mat_ptrs)
{
if(ref_man.contains(m)) //TODO: remove when the bug of nonsense factor is corrected
// if(ref_man.contains(m)) // useless assuming we release only the acquired factors
ref_man.release(m);
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment