Mentions légales du service

Skip to content
Snippets Groups Projects
Commit d41cc07e authored by hhakim's avatar hhakim
Browse files

Implement a C++ optimization of N matrices by choosing the best order of...

Implement a C++ optimization of N matrices by choosing the best order of multiplications while cb82a735 optimized only for 3 factors.
parent 2d889ec7
No related branches found
No related tags found
No related merge requests found
...@@ -200,6 +200,8 @@ t_local_compute_projection.stop(); ...@@ -200,6 +200,8 @@ t_local_compute_projection.stop();
template<typename FPP,Device DEVICE,typename FPP2> template<typename FPP,Device DEVICE,typename FPP2>
void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_grad_over_c_ext_opt() void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_grad_over_c_ext_opt()
{ {
//#define mul_3_facts multiply_order_opt
#define mul_3_facts multiply_order_opt_ext // this one only optimize the product on factor ends but for three factors it doesn't change anything comparing to multiply_order_opt
// compute error = m_lambda*L*S*R-data // compute error = m_lambda*L*S*R-data
error = data; error = data;
std::vector<MatDense<FPP,DEVICE>*> facts; std::vector<MatDense<FPP,DEVICE>*> facts;
...@@ -208,14 +210,14 @@ void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_grad_over_c_ext_opt() ...@@ -208,14 +210,14 @@ void Faust::Palm4MSA<FPP,DEVICE,FPP2>::compute_grad_over_c_ext_opt()
facts = { &RorL[m_indFact], &S[m_indFact], &LorR }; facts = { &RorL[m_indFact], &S[m_indFact], &LorR };
else else
facts = { &LorR, &S[m_indFact], &RorL[m_indFact] }; facts = { &LorR, &S[m_indFact], &RorL[m_indFact] };
multiply_order_opt(facts, error, (FPP) m_lambda, (FPP) -1.0); mul_3_facts(facts, error, (FPP) m_lambda, (FPP) -1.0);
// compute m_lambda/c * L'*error*R' // compute m_lambda/c * L'*error*R'
if(isUpdateWayR2L) if(isUpdateWayR2L)
facts = { &RorL[m_indFact], &error, &LorR }; facts = { &RorL[m_indFact], &error, &LorR };
else else
facts = {&LorR, &error, &RorL[m_indFact]}; facts = {&LorR, &error, &RorL[m_indFact]};
tc_flags = {TorH, 'N', TorH}; tc_flags = {TorH, 'N', TorH};
multiply_order_opt(facts, grad_over_c, (FPP) m_lambda/c, (FPP)0, tc_flags); mul_3_facts(facts, grad_over_c, (FPP) m_lambda/c, (FPP)0, tc_flags);
isGradComputed = true; isGradComputed = true;
} }
......
...@@ -167,7 +167,9 @@ namespace Faust ...@@ -167,7 +167,9 @@ namespace Faust
template<typename FPP> template<typename FPP>
FPP fabs(complex<FPP> c); FPP fabs(complex<FPP> c);
template<typename FPP, Device DEVICE> template<typename FPP, Device DEVICE>
void multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha=1.0, FPP beta_out=.0, std::vector<char> transconj_flags = std::vector<char>({'N'})); void multiply_order_opt_ext(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha=1.0, FPP beta_out=.0, std::vector<char> transconj_flags = std::vector<char>({'N'}));
template<typename FPP, Device DEVICE>
void multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha=1.0, FPP beta_out=.0, std::vector<char> transconj_flags = std::vector<char>({'N'}));
} }
......
...@@ -733,7 +733,7 @@ FPP Faust::fabs(complex<FPP> c) ...@@ -733,7 +733,7 @@ FPP Faust::fabs(complex<FPP> c)
* \note: the vector facts is altered after the call! Don't reuse it. * \note: the vector facts is altered after the call! Don't reuse it.
*/ */
template<typename FPP, Device DEVICE> template<typename FPP, Device DEVICE>
void Faust::multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha/* =1.0*/, FPP beta_out/*=.0*/, std::vector<char> transconj_flags /* = {'N'}*/) void Faust::multiply_order_opt_ext(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha/* =1.0*/, FPP beta_out/*=.0*/, std::vector<char> transconj_flags /* = {'N'}*/)
{ {
Faust::MatDense<FPP, DEVICE> tmpr, tmpl; Faust::MatDense<FPP, DEVICE> tmpr, tmpl;
int nfacts = facts.size(); int nfacts = facts.size();
...@@ -750,7 +750,6 @@ void Faust::multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, ...@@ -750,7 +750,6 @@ void Faust::multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts,
R1nc = R1->getNbCol(); R1nc = R1->getNbCol();
L1nr = L1->getNbRow(); L1nr = L1->getNbRow();
L1nc = L1->getNbCol(); L1nc = L1->getNbCol();
//TODO: allow matrix mul. of two factors in the middle with a deletion of the right factor afterward
if(R1nr * R1nc * R2->getNbCol() < L1nr * L1nc * L2->getNbCol()) if(R1nr * R1nc * R2->getNbCol() < L1nr * L1nc * L2->getNbCol())
{ {
gemm(*R1, *R2, tmpr, (FPP)1.0, (FPP)0.0, transconj_flags[transconj_flags.size()>ri?ri:0], transconj_flags[transconj_flags.size()>ri+1?ri+1:0]); gemm(*R1, *R2, tmpr, (FPP)1.0, (FPP)0.0, transconj_flags[transconj_flags.size()>ri?ri:0], transconj_flags[transconj_flags.size()>ri+1?ri+1:0]);
...@@ -770,7 +769,74 @@ void Faust::multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, ...@@ -770,7 +769,74 @@ void Faust::multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts,
} }
// last mul // last mul
gemm(*facts[ri], *facts[li], out, alpha, beta_out, ri==0?transconj_flags[0]:'N', li==nfacts-1&&transconj_flags.size()>li?transconj_flags[li]:'N'); gemm(*facts[ri], *facts[li], out, alpha, beta_out, ri==0?transconj_flags[0]:'N', li==nfacts-1&&transconj_flags.size()>li?transconj_flags[li]:'N');
facts.erase(facts.begin()); facts.erase(facts.begin(), facts.end());
} }
template<typename FPP, Device DEVICE>
void Faust::multiply_order_opt(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha/* =1.0*/, FPP beta_out/*=.0*/, std::vector<char> transconj_flags /* = {'N'}*/)
{
	/**
	 * Computes out = alpha * facts[0]*...*facts[n-1] + beta_out * out, greedily
	 * contracting the cheapest adjacent pair of factors first.
	 *
	 * \param facts at least two factors; the vector is emptied by the call, don't reuse it.
	 * \param transconj_flags one 'N'/'T'/'H' flag per factor, or a single flag
	 *        applied to the first factor only ('N' is used for the others).
	 */
	std::vector<Faust::MatDense<FPP,DEVICE>*> tmp_facts; // temporary product results, freed at the end
	Faust::MatDense<FPP, DEVICE>* tmp;
	int nfacts = facts.size();
	Faust::MatDense<FPP,DEVICE> *Si, *Sj;
	// cost estimate of each adjacent product facts[i]*facts[i+1]
	// (size_t: the dimension product can overflow an int)
	std::vector<size_t> complexity(nfacts-1);
	for(int i = 0; i < nfacts-1; i++) // fix: cover ALL nfacts-1 adjacent pairs (was nfacts-2, leaving the last cost zero-initialized and always selected first)
	{
		Si = facts[i];
		Sj = facts[i+1];
		complexity[i] = (size_t) Si->getNbRow() * Si->getNbCol() * Sj->getNbCol();
	}
	int idx; // marks the factor to update with a product of contiguous factors
	bool multiplying_tmp_factor = false; // allows to avoid to allocate uselessly a tmp factor if Si or Sj are already tmp factors
	while(facts.size() > 2)
	{
		// find the least complex product facts[idx]*facts[idx+1]
		idx = std::distance(complexity.begin(), std::min_element(complexity.begin(), complexity.end()));
		Si = facts[idx];
		Sj = facts[idx+1];
		for(auto Tit = tmp_facts.begin(); Tit != tmp_facts.end(); Tit++)
		{
			if(Sj == *Tit)
			{// Sj is a tmp factor: reuse it as the output buffer
				multiplying_tmp_factor = true;
				tmp = Sj;
				break;
			}
			else if(Si == *Tit)
			{
				multiplying_tmp_factor = true;
				tmp = Si;
				break;
			}
		}
		if(! multiplying_tmp_factor)
		{
			tmp = new Faust::MatDense<FPP, DEVICE>();
			tmp_facts.push_back(tmp);
		}
		// else no need to instantiate a new tmp, the result overwrites the tmp operand
		// NOTE(review): this assumes Faust::gemm supports an output aliasing one of its inputs — TODO confirm
		gemm(*Si, *Sj, *tmp, (FPP)1.0, (FPP)0.0, transconj_flags[transconj_flags.size()>idx?idx:0], transconj_flags[transconj_flags.size()>idx+1?idx+1:0]);
		facts.erase(facts.begin()+idx+1);
		facts[idx] = tmp;
		// keep the per-factor flags aligned with facts: drop the absorbed factor's flag
		// (fix: it was left in place, shifting every subsequent flag onto the wrong factor)
		// and mark the merged factor as plain 'N' (its transpose/conjugate is already applied)
		if(transconj_flags.size() > idx+1)
			transconj_flags.erase(transconj_flags.begin()+idx+1);
		if(transconj_flags.size() > idx)
			transconj_flags[idx] = 'N';
		// keep complexity aligned with facts too
		// (fix: the merged pair's entry was never erased, so min_element could pick a
		// stale — or, once facts shrank, out-of-range — pair)
		complexity.erase(complexity.begin()+idx);
		// update complexity around the new factor
		if(idx > 0)
			complexity[idx-1] = (size_t) facts[idx-1]->getNbRow() * facts[idx-1]->getNbCol() * facts[idx]->getNbCol();
		if(idx < (int) facts.size()-1)
			complexity[idx] = (size_t) facts[idx]->getNbRow() * facts[idx]->getNbCol() * facts[idx+1]->getNbCol();
		multiplying_tmp_factor = false;
	}
	// last mul
	gemm(*facts[0], *facts[1], out, alpha, beta_out, transconj_flags[0], transconj_flags.size()>1?transconj_flags[1]:'N');
	facts.clear();
	// delete all tmp facts
	for(auto Tit = tmp_facts.begin(); Tit != tmp_facts.end(); Tit++)
	{
		delete *Tit;
	}
}
#endif #endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment