Commit ec0b97bf authored by hhakim

Update gemm_gen to handle MatBSR too.

parent ce8e8d2a
@@ -30,9 +30,13 @@ template<typename T,int BlockStorageOrder=0> class BSRMat;
namespace Faust
{
// forward decl. (faust_linear_algebra.h) for friendship
+template<typename FPP> void gemm_gen(const MatGeneric<FPP, Cpu>& A, const MatGeneric<FPP, Cpu>& B, MatDense<FPP, Cpu>& out, const FPP alpha/*=FPP(1.0)*/, const FPP beta/*=FPP(0.0)*/, const char opA/*='N'*/, const char opB/*='N'*/);
template<typename FPP>
class MatBSR<FPP,Cpu> : public MatGeneric<FPP,Cpu>
{
+friend void gemm_gen<>(const MatGeneric<FPP, Cpu>& A, const MatGeneric<FPP, Cpu>& B, MatDense<FPP, Cpu>& out, const FPP alpha/*=FPP(1.0)*/, const FPP beta/*=FPP(0.0)*/, const char opA/*='N'*/, const char opB/*='N'*/);
BSRMat<FPP> bmat; // low-level BSRMat
MatBSR() : MatGeneric<FPP, Cpu>() {}
MatBSR(BSRMat<FPP>& mat);
......
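For context, the new friend line grants only the matching specialization of the free function template gemm_gen access to MatBSR's private bmat member; that is why the forward declaration has to precede the class. A minimal self-contained sketch of this friend-of-a-template-specialization pattern (Box and peek are illustrative names, not FAuST code):

    #include <iostream>

    template<typename T> class Box;                  // forward declaration of the class
    template<typename T> void peek(const Box<T>& b); // forward declaration of the function template

    template<typename T>
    class Box
    {
        // befriend only the matching specialization peek<T>, not every function
        friend void peek<>(const Box<T>& b);
        T secret; // private, like MatBSR's bmat member
    public:
        explicit Box(T s) : secret(s) {}
    };

    template<typename T>
    void peek(const Box<T>& b)
    {
        std::cout << b.secret << '\n'; // allowed thanks to the friend declaration
    }

    int main()
    {
        peek(Box<int>(42)); // prints 42
    }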
@@ -506,11 +506,16 @@ void Faust::Transform<FPP,Cpu>::get_product(Faust::MatDense<FPP,Cpu> &mat, const
{
if(data.size() == 1)
{
+auto end = this->size()-1;
// just one matrix in the Faust, return a copy as dense matrix
-if(dynamic_cast<MatSparse<FPP,Cpu>*>(data[this->size()-1]))
-	mat = *dynamic_cast<MatSparse<FPP,Cpu>*>(data[this->size()-1]);
-else if(dynamic_cast<MatDense<FPP,Cpu>*>(data[this->size()-1]))
-	mat = *dynamic_cast<MatDense<FPP,Cpu>*>(data[this->size()-1]);
+if(dynamic_cast<MatSparse<FPP,Cpu>*>(data[end]))
+	mat = *dynamic_cast<MatSparse<FPP,Cpu>*>(data[end]);
+else if(dynamic_cast<MatDense<FPP,Cpu>*>(data[end]))
+	mat = *dynamic_cast<MatDense<FPP,Cpu>*>(data[end]);
+else if(dynamic_cast<MatBSR<FPP, Cpu>*>(data[end]))
+{
+	mat = dynamic_cast<MatBSR<FPP, Cpu>*>(data[end])->to_dense();
+}
if(isConj)
mat.conjugate();
return;
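Aside: MatBSR stores its nonzeros as fixed-size dense blocks in Block Sparse Row layout, which is why the single-factor case above falls back on a to_dense() conversion. A rough sketch of the format and of a dense expansion (an independent illustration; the struct and field names are assumptions, not FAuST's BSRMat):

    #include <vector>

    // Block Sparse Row: like CSR, but each "nonzero" is a dense bnrows x bncols block.
    struct BSR
    {
        int bnrows, bncols;       // block dimensions
        std::vector<int> browptr; // size nblock_rows+1: start of each block row
        std::vector<int> bcolinds;// one block-column index per stored block
        std::vector<double> data; // blocks stored back to back, bnrows*bncols each
    };

    // Expand to a dense row-major matrix, the moral equivalent of MatBSR::to_dense().
    std::vector<double> to_dense(const BSR& m, int nrows, int ncols)
    {
        std::vector<double> d(nrows * ncols, 0.0);
        int nbrows = (int)m.browptr.size() - 1;
        for(int bi = 0; bi < nbrows; bi++)
            for(int k = m.browptr[bi]; k < m.browptr[bi + 1]; k++)
            {
                const double* blk = &m.data[(size_t)k * m.bnrows * m.bncols];
                int r0 = bi * m.bnrows, c0 = m.bcolinds[k] * m.bncols;
                for(int r = 0; r < m.bnrows; r++)
                    for(int c = 0; c < m.bncols; c++)
                        d[(r0 + r) * ncols + (c0 + c)] = blk[r * m.bncols + c];
            }
        return d;
    }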
@@ -519,7 +524,7 @@ void Faust::Transform<FPP,Cpu>::get_product(Faust::MatDense<FPP,Cpu> &mat, const
{
// at least two factors, compute the first product (of the last two matrices)
// it avoids making a copy of the last factor
-gemm_gen(*data[this->size()-2], *data[this->size()-1], mat);
+gemm_gen(*data[this->size()-2], *data[this->size()-1], mat, FPP(1.0), FPP(0.0), 'N', 'N');
}
for (int i=this->size()-3; i >= 0; i--)
{
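The loop just opened above (its body is truncated in this hunk) folds the remaining factors in from right to left, so only the product of the last two factors has to be formed explicitly. Schematically, with plain Eigen dense matrices instead of the Faust factor types (a hypothetical sketch):

    #include <Eigen/Dense>
    #include <vector>

    // Multiply factors[0]*factors[1]*...*factors[n-1] right to left:
    // start from the product of the last two, then fold the rest in.
    Eigen::MatrixXd product(const std::vector<Eigen::MatrixXd>& factors)
    {
        size_t n = factors.size();
        if(n == 1)
            return factors[0];                 // single factor: plain copy
        Eigen::MatrixXd acc = factors[n - 2] * factors[n - 1];
        for(int i = (int)n - 3; i >= 0; i--)   // same index walk as the loop above
            acc = factors[i] * acc;
        return acc;
    }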
@@ -1063,12 +1068,11 @@ void Faust::Transform<FPP,Cpu>::get_fact(const faust_unsigned_int id,
s_outer_count_ptr = tmat.getRowPtr();
s_inner_ptr = tmat.getColInd();
s_elts = tmat.getValuePtr();
-//do the copy here, otherwise we'll lose tmat and its buffers when out of scope
+//do the copy here, otherwise we'll lose the temporary matrix and its buffers when it goes out of scope
memcpy(d_outer_count_ptr, s_outer_count_ptr, sizeof(int)*(*num_cols+1));
memcpy(d_inner_ptr, s_inner_ptr, sizeof(int)**nnz);
memcpy(d_elts, s_elts, sizeof(FPP)**nnz);
// swap num_cols and num_rows
// (with only these 2 variables -- F2 arithmetic trick)
*num_cols = *num_cols^*num_rows;
*num_rows = *num_cols^*num_rows;
*num_cols = *num_cols^*num_rows;
......
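The comment above refers to the classic XOR swap: because x ^ x = 0 (XOR is addition over F2, which is its own inverse), three XOR assignments exchange two integers without a temporary. A standalone illustration (std::swap remains the idiomatic choice, and the trick breaks when both operands alias the same object):

    #include <cassert>

    int main()
    {
        unsigned a = 3, b = 5;
        a = a ^ b; // a == a0 ^ b0
        b = a ^ b; // b == (a0 ^ b0) ^ b0 == a0
        a = a ^ b; // a == (a0 ^ b0) ^ a0 == b0
        assert(a == 5 && b == 3);
        return 0;
    }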
@@ -67,7 +67,7 @@ namespace Faust
// Computes alpha*typeA(A)*typeB(B) + beta*C into C.
template<typename FPP>
-void gemm_gen(const Faust::MatGeneric<FPP,Cpu> & A,const Faust::MatGeneric<FPP,Cpu> & B, Faust::MatDense<FPP,Cpu> & C, const FPP alpha=FPP(1.0), const FPP beta=FPP(0.0), const char typeA='N', const char typeB='N');
+void gemm_gen(const Faust::MatGeneric<FPP,Cpu> & A,const Faust::MatGeneric<FPP,Cpu> & B, Faust::MatDense<FPP,Cpu> & C, const FPP alpha/*=FPP(1.0)*/, const FPP beta/*=FPP(0.0)*/, const char typeA/*='N'*/, const char typeB/*='N'*/);
//! \fn spgemm
//! \brief performs sparse matrix multiplication
......
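The declared contract is the usual BLAS gemm convention: C = alpha*op(A)*op(B) + beta*C, where op is selected by 'N' (identity), 'T' (transpose) or 'H' (conjugate transpose). A naive reference implementation over plain row-major arrays, for illustration only (names and storage layout are assumptions, unrelated to the Faust types):

    #include <complex>
    #include <vector>

    using cplx = std::complex<double>;
    using Mat = std::vector<cplx>; // row-major storage

    // Element (i,j) of op(M), where M is stored row-major with `cols` columns.
    static cplx at(const Mat& M, int cols, int i, int j, char op)
    {
        if(op == 'N') return M[i * cols + j];
        cplx v = M[j * cols + i];              // 'T' and 'H' swap the indices
        return op == 'H' ? std::conj(v) : v;
    }

    // C = alpha*op(A)*op(B) + beta*C with op(A): m x k, op(B): k x n, C: m x n.
    void ref_gemm(const Mat& A, int acols, const Mat& B, int bcols, Mat& C,
                  int m, int n, int k, cplx alpha, cplx beta, char opA, char opB)
    {
        for(int i = 0; i < m; i++)
            for(int j = 0; j < n; j++)
            {
                cplx acc(0);
                for(int l = 0; l < k; l++)
                    acc += at(A, acols, i, l, opA) * at(B, bcols, l, j, opB);
                C[i * n + j] = alpha * acc + beta * C[i * n + j];
            }
    }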
@@ -814,7 +814,7 @@ FPP Faust::power_iteration(const Faust::LinearOperator<FPP,Cpu> & A, const faus
handleError("linear_algebra "," power_iteration : Faust::Transform<FPP,Cpu> 1 must be a square matrix");
}
Faust::Vect<FPP,Cpu> xk(nb_col);
-xk.setOnes();
+xk.setRand(); // a random start vector is very unlikely to be orthogonal to the dominant eigenvector (safer than setOnes)
Faust::Vect<FPP,Cpu> xk_norm(nb_col);
FPP lambda_old=1.0;
FPP lambda = 0.0;
@@ -825,6 +825,7 @@ FPP Faust::power_iteration(const Faust::LinearOperator<FPP,Cpu> & A, const faus
i++;
lambda_old = lambda;
xk_norm = xk;
+std::cout << "xk norm:" << xk.norm() << std::endl;
xk_norm.normalize();
xk = A.multiply(xk_norm);
lambda = xk_norm.dot(xk);
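For reference, the surrounding loop is standard power iteration: repeatedly apply A to a normalized vector; the quotient dot(x_k, A x_k) converges to the dominant eigenvalue provided the start vector has a nonzero component along the dominant eigenvector, hence the switch to random initialization above. A compact standalone version (a sketch with Eigen dense types, not the Faust operators):

    #include <cmath>
    #include <Eigen/Dense>

    // Dominant eigenvalue of a square matrix by power iteration (illustrative).
    double power_iteration(const Eigen::MatrixXd& A, int max_it = 1000, double tol = 1e-10)
    {
        // random start: almost surely not orthogonal to the dominant eigenvector
        Eigen::VectorXd xk = Eigen::VectorXd::Random(A.cols());
        double lambda = 0.0, lambda_old = 0.0;
        for(int i = 0; i < max_it; i++)
        {
            lambda_old = lambda;
            xk.normalize();                 // keep the iterate at unit norm
            Eigen::VectorXd Axk = A * xk;
            lambda = xk.dot(Axk);           // Rayleigh quotient estimate
            if(i > 0 && std::abs(lambda - lambda_old) < tol)
                break;
            xk = Axk;                       // next iterate
        }
        return lambda;
    }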
@@ -865,6 +866,10 @@ namespace Faust
{
template<typename FPP> void gemm_gen(const MatGeneric<FPP, Cpu>& A, const MatGeneric<FPP, Cpu>& B, MatDense<FPP, Cpu>& out, const FPP alpha/*=FPP(1.0)*/, const FPP beta/*=FPP(0.0)*/, const char opA/*='N'*/, const char opB/*='N'*/)
{
+if(opA != 'N' && opA != 'T' && opA != 'H')
+	throw std::runtime_error("opA must be among 'N', 'T', 'H'");
+if(opB != 'N' && opB != 'T' && opB != 'H')
+	throw std::runtime_error("opB must be among 'N', 'T', 'H'");
if(dynamic_cast<const MatDense<FPP, Cpu>*>(&A) != nullptr && dynamic_cast<const MatDense<FPP, Cpu>*>(&B) != nullptr)
{
gemm<FPP>(dynamic_cast<const MatDense<FPP, Cpu>&>(A), dynamic_cast<const MatDense<FPP, Cpu>&>(B), out, alpha, beta, opA, opB);
@@ -917,10 +922,124 @@ namespace Faust
out.mat = alpha*to_eigen_sp(dynamic_cast<const MatSparse<FPP, Cpu>&>(A).mat, opA)*to_eigen_sp(dynamic_cast<const MatSparse<FPP, Cpu>&>(B).mat, opB);
#endif
}
+else if(dynamic_cast<const MatBSR<FPP, Cpu>*>(&A))
+{
+	//TODO: refactor
+	const MatDense<FPP, Cpu>* dsB = nullptr;
+	const MatSparse<FPP, Cpu>* spB = nullptr;
+	const MatBSR<FPP, Cpu>* bsrB = nullptr;
+	const MatBSR<FPP, Cpu>* bsrA = dynamic_cast<const MatBSR<FPP, Cpu>*>(&A);
+	if(opA == 'N')
+	{
+		if(opB != 'N')
+		{
+			// materialize op(B) once and recurse with opB == 'N';
+			// note that transposing the product would be wrong: (A*B)^T == B^T*A^T, not A*B^T
+			if(spB = dynamic_cast<const MatSparse<FPP, Cpu>*>(&B))
+			{
+				MatSparse<FPP, Cpu> opB_mat(*spB);
+				opB_mat.transpose();
+				if(opB == 'H') opB_mat.conjugate();
+				return gemm_gen(A, opB_mat, out, alpha, beta, 'N', 'N');
+			}
+			else if(dsB = dynamic_cast<const MatDense<FPP, Cpu>*>(&B))
+			{
+				MatDense<FPP, Cpu> opB_mat(*dsB);
+				opB_mat.transpose();
+				if(opB == 'H') opB_mat.conjugate();
+				return gemm_gen(A, opB_mat, out, alpha, beta, 'N', 'N');
+			}
+			else if(bsrB = dynamic_cast<const MatBSR<FPP, Cpu>*>(&B))
+			{
+				MatBSR<FPP, Cpu> opB_mat(*bsrB);
+				opB_mat.transpose();
+				if(opB == 'H') opB_mat.conjugate();
+				return gemm_gen(A, opB_mat, out, alpha, beta, 'N', 'N');
+			}
+		}
+		Eigen::Matrix<FPP, Eigen::Dynamic, Eigen::Dynamic> dsm;
+		if(spB = dynamic_cast<const MatSparse<FPP, Cpu>*>(&B))
+			dsm = bsrA->bmat.mul(spB->mat);
+		else if(dsB = dynamic_cast<const MatDense<FPP, Cpu>*>(&B))
+			dsm = bsrA->bmat.mul(dsB->mat);
+		else if(bsrB = dynamic_cast<const MatBSR<FPP, Cpu>*>(&B))
+			dsm = bsrA->bmat.mul(bsrB->bmat); // alpha is applied below, exactly once
+		if(beta == FPP(0))
+		{
+			out.mat = alpha*dsm;
+		}
+		else
+		{
+			out.mat *= beta;
+			out.mat += alpha*dsm;
+		}
+	}
+	else if(opA == 'T')
+	{
+		MatBSR<FPP, Cpu> transpA(*bsrA);
+		transpA.transpose();
+		return gemm_gen(transpA, B, out, alpha, beta, 'N', opB); // recurse on A^T, not on A
+	}
+	else // opA == 'H', guaranteed by the check above
+	{
+		MatBSR<FPP, Cpu> adjA(*bsrA);
+		adjA.adjoint();
+		return gemm_gen(adjA, B, out, alpha, beta, 'N', opB); // recurse on A^H, not on A
+	}
+}
+else if(dynamic_cast<const MatBSR<FPP, Cpu>*>(&B))
+{
+	//TODO: refactor
+	// B is a MatBSR but A is not (the branch above would have caught it):
+	// swap the operands so the BSR-on-the-left branch does the work, then
+	// transpose/adjoint the result back. When beta != 0, out is pre-flipped
+	// (and the scalars conjugated in the adjoint cases) so that the final
+	// transpose/adjoint leaves the beta*out term intact.
+	const MatBSR<FPP, Cpu>* bsrB = dynamic_cast<const MatBSR<FPP, Cpu>*>(&B);
+	if(opA == 'N' && opB == 'N')
+	{
+		// A*B == (B^T*A^T)^T
+		if(beta != FPP(0)) out.transpose();
+		gemm_gen(B, A, out, alpha, beta, 'T', 'T');
+		out.transpose();
+	}
+	else if(opA == 'N' && opB == 'T')
+	{
+		// A*B^T == (B*A^T)^T
+		if(beta != FPP(0)) out.transpose();
+		gemm_gen(B, A, out, alpha, beta, 'N', 'T');
+		out.transpose();
+	}
+	else if(opA == 'N' && opB == 'H')
+	{
+		// A*B^H == (B*A^H)^H; conjugate the scalars so the final adjoint restores them
+		if(beta != FPP(0)) out.adjoint();
+		gemm_gen(B, A, out, Eigen::numext::conj(alpha), Eigen::numext::conj(beta), 'N', 'H');
+		out.adjoint();
+	}
+	else if(opA == 'T' && opB == 'N')
+	{
+		// A^T*B == (B^T*A)^T
+		if(beta != FPP(0)) out.transpose();
+		gemm_gen(B, A, out, alpha, beta, 'T', 'N');
+		out.transpose();
+	}
+	else if(opA == 'T' && opB == 'T')
+	{
+		// A^T*B^T == (B*A)^T
+		if(beta != FPP(0)) out.transpose();
+		gemm_gen(B, A, out, alpha, beta, 'N', 'N');
+		out.transpose();
+	}
+	else if(opA == 'T' && opB == 'H')
+	{
+		// A^T*B^H == (conj(B)*A)^T; conjugate B (the BSR operand) rather than
+		// the generic A, which cannot be copied through its base class
+		MatBSR<FPP, Cpu> B_conj(*bsrB);
+		B_conj.conjugate();
+		if(beta != FPP(0)) out.transpose();
+		gemm_gen(B_conj, A, out, alpha, beta, 'N', 'N');
+		out.transpose();
+	}
+	else if(opA == 'H' && opB == 'N')
+	{
+		// A^H*B == (B^H*A)^H
+		if(beta != FPP(0)) out.adjoint();
+		gemm_gen(B, A, out, Eigen::numext::conj(alpha), Eigen::numext::conj(beta), 'H', 'N');
+		out.adjoint();
+	}
+	else if(opA == 'H' && opB == 'T')
+	{
+		// A^H*B^T == (conj(B)*A)^H
+		MatBSR<FPP, Cpu> B_conj(*bsrB);
+		B_conj.conjugate();
+		if(beta != FPP(0)) out.adjoint();
+		gemm_gen(B_conj, A, out, Eigen::numext::conj(alpha), Eigen::numext::conj(beta), 'N', 'N');
+		out.adjoint();
+	}
+	else // opA == 'H' && opB == 'H'
+	{
+		// A^H*B^H == (B*A)^H
+		if(beta != FPP(0)) out.adjoint();
+		gemm_gen(B, A, out, Eigen::numext::conj(alpha), Eigen::numext::conj(beta), 'N', 'N');
+		out.adjoint();
+	}
+}
else
{
-throw std::runtime_error("faust_linear_algebra mul function doesn't handle other type of factors than MatDense or MatSparse");
+throw std::runtime_error("faust_linear_algebra mul function doesn't handle other types of factors than MatDense, MatSparse or MatBSR.");
}
out.update_dims();
}
}
......
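The swapped-operand cases rely on the identities (B^T A^T)^T = A B and (B^H A^H)^H = A B. A quick Eigen sanity check of the pattern (illustrative only):

    #include <Eigen/Dense>
    #include <iostream>

    int main()
    {
        using Mat = Eigen::MatrixXcd;
        Mat A = Mat::Random(3, 4), B = Mat::Random(4, 2);
        // A*B computed directly vs. via the swapped, transposed product
        std::cout << (A * B - (B.transpose() * A.transpose()).transpose()).norm() << '\n'; // ~0
        // adjoint variant: (B^H * A^H)^H == A*B as well
        std::cout << (A * B - (B.adjoint() * A.adjoint()).adjoint()).norm() << '\n';       // ~0
    }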