Mentions légales du service

Skip to content
Snippets Groups Projects
Commit f02508e5 authored by hhakim's avatar hhakim
Browse files

Fix PALM4MSA2020 gradient computation in complex case: replace transpose ops...

Fix PALM4MSA2020 gradient computation in the complex case: replace transpose operations with conjugate transposes, and split the gradient computation from its application to the factor (even though the all-at-once operation was correct, it made the DFT factorization fail with a large error).
parent 48b7a96c
No related branches found
No related tags found
No related merge requests found
......@@ -80,7 +80,6 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
auto vec_Si_minus_1 = { *(S.begin()+i-1) };
if(pL[i] != nullptr) delete pL[i]; //TODO: maybe to replace by a TransformHelper stored in the stack to avoid deleting each time
pL[i] = new TransformHelper<FPP,DEVICE>(*pL[i-1], vec_Si_minus_1);
// if the ctor args are GPU-enabled so is pL[i]
if(packing_RL) ((TransformHelperGen<FPP,DEVICE>*)pL[i])->pack_factors(prod_mod);
}
// all pL[i] Fausts are composed at most of one factor matrix
......@@ -171,6 +170,7 @@ void Faust::palm4msa2(const Faust::MatDense<FPP,DEVICE>& A,
auto err = calc_rel_err(S, A, lambda);
std::cout << " relative error: " << err;
std::cout << " (call id: " << id << ")" << std::endl;
std::cout << " lambda=" << lambda << std::endl;
}
}
i++;
......@@ -222,10 +222,10 @@ void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVI
{ //no L factor for factor f_id
alpha_R = - lambda/c;
beta_R = 1;
gemm(tmp, *LorR, D, alpha_R, beta_R, 'N', 'T');
gemm(tmp, *LorR, D, alpha_R, beta_R, 'N', 'H');
}
else
gemm(tmp, *LorR, tmp, alpha_R, beta_R, 'N', 'T');
gemm(tmp, *LorR, tmp, alpha_R, beta_R, 'N', 'H');
}
if(pL_sz > 0)
{
......@@ -238,7 +238,7 @@ void Faust::compute_n_apply_grad1(const int f_id, const Faust::MatDense<FPP,DEVI
}
alpha_L = -lambda/c;
beta_L = 1;
gemm(*LorR, tmp, D, alpha_L, beta_L, 'T', 'N');
gemm(*LorR, tmp, D, alpha_L, beta_L, 'H', 'N');
}
}
......@@ -246,6 +246,7 @@ template <typename FPP, FDevice DEVICE>
void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVICE> &A, Faust::TransformHelper<FPP,DEVICE>& S, std::vector<TransformHelper<FPP,DEVICE>*> &pL, std::vector<TransformHelper<FPP,DEVICE>*>& pR, const Real<FPP>& lambda, const Real<FPP> &c, Faust::MatDense<FPP,DEVICE> &out /* D */, const StoppingCriterion<Real<FPP>>& sc, Real<FPP> &error, const int prod_mod)
{
Faust::MatDense<FPP,DEVICE> tmp;
Faust::MatDense<FPP,DEVICE> grad_over_c;
Faust::MatDense<FPP,DEVICE> & D = out;
Faust::MatDense<FPP,DEVICE> *_L, *_R, __L, __R;
Faust::MatDense<FPP,DEVICE> * LorR;
......@@ -287,8 +288,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
error = tmp.norm();
// compute m_lambda/c * L'*error*R'
facts = { _L, &tmp, _R };
tc_flags = {'T', 'N', 'T'};
mul_3_facts(facts, D, (FPP) - lambda/c, (FPP)1, tc_flags);
tc_flags = {'H', 'N', 'H'};
}
else if(pR_sz > 0)
{
......@@ -299,8 +299,7 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
error = tmp.norm();
// compute m_lambda/c * L'*error*R'
facts = { &tmp, _R };
tc_flags = { 'N', 'T'};
mul_3_facts(facts, D, (FPP) - lambda/c, (FPP)1, tc_flags);
tc_flags = { 'N', 'H'};
}
else //if(pL_sz > 0)
{
......@@ -311,9 +310,12 @@ void Faust::compute_n_apply_grad2(const int f_id, const Faust::MatDense<FPP,DEVI
error = tmp.norm();
// compute m_lambda/c * L'*error*R'
facts = { _L, &tmp};
tc_flags = {'T', 'N'};
mul_3_facts(facts, D, (FPP) - lambda/c, (FPP)1, tc_flags);
tc_flags = {'H', 'N'};
}
// mul_3_facts(facts, D, (FPP) - lambda/c, (FPP)1, tc_flags); // this all-at-once call has shown calculation errors when using complex matrices (DFT factorization)
// so do it in two steps: 1) compute the gradient, then 2) apply it separately to D (the current factor)
mul_3_facts(facts, grad_over_c, (FPP) lambda/c, (FPP)0, tc_flags);
D -= grad_over_c;
}
template<typename FPP, FDevice DEVICE>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment