Mentions légales du service

Skip to content
Snippets Groups Projects
Commit f0eb6229 authored by hhakim's avatar hhakim
Browse files

Add a new signature of multiply_order_opt_all_ends to work on MatGeneric.

parent ee44fc8b
No related branches found
No related tags found
No related merge requests found
......@@ -38,6 +38,9 @@ namespace Faust
/**
 * Declaration of the multiply_order_opt_all_ends overload whose factors are dense matrices (MatDense pointers).
 *
 * \param facts the chain of factors to multiply; the visible end of the definition erases every entry of this vector before returning.
 * \param out the dense matrix that receives the product.
 * \param alpha scalar argument, 1.0 by default (NOTE(review): presumably scales the product as in the MatGeneric overload -- confirm against the definition).
 * \param beta_out scalar argument, .0 by default (NOTE(review): presumably weights the previous content of out -- confirm against the definition).
 * \param transconj_flags one-character operation flags for the factors, {'N'} by default; taken by value, so the caller's vector is never modified.
 */
template<typename FPP, FDevice DEVICE>
void multiply_order_opt_all_ends(std::vector<MatDense<FPP,DEVICE>*>& facts, MatDense<FPP,DEVICE>& out, FPP alpha=1.0, FPP beta_out=.0, std::vector<char> transconj_flags = std::vector<char>({'N'}));
/**
 * Declaration of the multiply_order_opt_all_ends overload whose factors are MatGeneric pointers; the product is still written to a dense matrix.
 *
 * The definition repeatedly multiplies two adjacent factors taken at one end of the chain (whichever end has the smaller rows*cols*next_cols estimate) until two operands remain, then performs a last gemm_gen into out.
 *
 * \param facts the chain of factors to multiply; the definition overwrites some entries with pointers to its own temporaries and empties the vector before returning.
 * \param out receives the result of the last gemm_gen call.
 * \param alpha scalar forwarded to the last gemm_gen call only (intermediate products use 1.0).
 * \param beta_out scalar forwarded to the last gemm_gen call only (intermediate products use 0.0).
 * \param transconj_flags one-character operation flags forwarded to gemm_gen, {'N'} by default; taken by value, so the caller's vector is never modified.
 */
template<typename FPP, FDevice DEVICE>
void multiply_order_opt_all_ends(std::vector<MatGeneric<FPP,DEVICE>*>& facts, MatDense<FPP,DEVICE>& out, FPP alpha=1.0, FPP beta_out=.0, std::vector<char> transconj_flags= {'N'});
/**
*
* This function does the same as multiply_order_opt_all_ends, but it can also multiply in the middle of the matrix chain (not only at its ends) when doing so gives a lower complexity.
......
......@@ -44,6 +44,46 @@ void Faust::multiply_order_opt_all_ends(std::vector<Faust::MatDense<FPP,DEVICE>*
facts.erase(facts.begin(), facts.end());
}
template<typename FPP, FDevice DEVICE>
void Faust::multiply_order_opt_all_ends(std::vector<Faust::MatGeneric<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha/* =1.0*/, FPP beta_out/*=.0*/, std::vector<char> transconj_flags /* = {'N'}*/)
{
	/*
	 * Multiplies the chain of factors into out by shrinking it from its two ends.
	 *
	 * While more than two operands remain, the pair at the front of the chain
	 * (facts[head], facts[head+1]) and the pair at the back (facts[tail-1],
	 * facts[tail]) are compared with the estimate rows * cols * next_cols; the
	 * front pair is multiplied only when its estimate is strictly smaller,
	 * otherwise the back pair is. The estimate reads getNbRow()/getNbCol()
	 * directly and does not consult transconj_flags.
	 *
	 * facts           in/out: entries are overwritten with pointers to local
	 *                 buffers along the way and the vector is emptied on return.
	 * out             receives the result of the last gemm_gen call.
	 * alpha, beta_out forwarded to the last gemm_gen call only; intermediate
	 *                 products are computed with 1.0 and 0.0.
	 * transconj_flags local copy of the operation flags handed to gemm_gen.
	 *
	 * NOTE(review): there is no guard for chains of fewer than two factors; the
	 * last call indexes facts[0] and facts[nfacts-1] unconditionally.
	 * NOTE(review): once a buffer has been stored in facts, it can be both an
	 * operand and the output of the next gemm_gen on that side -- confirm that
	 * gemm_gen tolerates this aliasing.
	 */
	Faust::MatDense<FPP, DEVICE> head_prod, tail_prod; // partial products of the front / back end
	const int nfacts = facts.size();
	int head = 0;
	int tail = nfacts-1;

	// Flag used inside the loop for the factor at idx: its own entry when the
	// flag vector is long enough, entry 0 otherwise.
	auto flag_at = [&transconj_flags](int idx) -> char
	{
		return transconj_flags.size() > idx ? transconj_flags[idx] : transconj_flags[0];
	};

	while(tail-head >= 2)
	{
		Faust::MatGeneric<FPP,DEVICE>* const head_a = facts[head];
		Faust::MatGeneric<FPP,DEVICE>* const head_b = facts[head+1];
		Faust::MatGeneric<FPP,DEVICE>* const tail_a = facts[tail-1];
		Faust::MatGeneric<FPP,DEVICE>* const tail_b = facts[tail];
		const faust_unsigned_int head_rows = head_a->getNbRow();
		const faust_unsigned_int head_cols = head_a->getNbCol();
		const faust_unsigned_int tail_rows = tail_a->getNbRow();
		const faust_unsigned_int tail_cols = tail_a->getNbCol();
		const auto head_cost = head_rows * head_cols * head_b->getNbCol();
		const auto tail_cost = tail_rows * tail_cols * tail_b->getNbCol();

		if(tail_cost <= head_cost)
		{
			// Back end is at least as cheap: fold the last two operands.
			gemm_gen(*tail_a, *tail_b, tail_prod, FPP(1.0), FPP(0.0), flag_at(tail-1), flag_at(tail));
			--tail;
			facts[tail] = &tail_prod;
			// The stored product already has its operands' flags applied.
			if(transconj_flags.size() > tail)
				transconj_flags[tail] = 'N';
		}
		else
		{
			// Front end is strictly cheaper: fold the first two operands.
			gemm_gen(*head_a, *head_b, head_prod, FPP(1.0), FPP(0.0), flag_at(head), flag_at(head+1));
			++head;
			facts[head] = &head_prod;
			// The stored product already has its operands' flags applied.
			if(transconj_flags.size() > head)
				transconj_flags[head] = 'N';
		}
	}

	// Last multiplication, the only one that uses alpha and beta_out.
	// An operand that is still an original factor keeps its flag; anything else
	// is passed with 'N' (NOTE(review): for the back operand a missing flag
	// entry also yields 'N' here, whereas the loop falls back to entry 0).
	const char first_op = head == 0 ? transconj_flags[0] : 'N';
	const char last_op = (tail == nfacts-1 && transconj_flags.size() > tail) ? transconj_flags[tail] : 'N';
	gemm_gen(*facts[head], *facts[tail], out, alpha, beta_out, first_op, last_op);
	// Drop every entry so that no pointer to the local buffers outlives this call.
	facts.clear();
}
template<typename FPP, FDevice DEVICE>
void Faust::multiply_order_opt_all_best(std::vector<Faust::MatDense<FPP,DEVICE>*>& facts, Faust::MatDense<FPP,DEVICE>& out, FPP alpha/* =1.0*/, FPP beta_out/*=.0*/, std::vector<char> transconj_flags /* = {'N'}*/)
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment