Mentions légales du service

Skip to content
Snippets Groups Projects
Commit a78b0576 authored by hhakim's avatar hhakim
Browse files

Add BSRMat::apply_op to factorize adjoint and transpose methods.

parent 90909b47
Branches
No related tags found
No related merge requests found
......@@ -27,6 +27,8 @@ using Real = typename Eigen::NumTraits<T>::Real;
#endif
template<typename T,int BlockStorageOrder=0> class BSRMat;
template<typename FPP, FDevice DEVICE> class Transform;
namespace Faust
{
......@@ -37,6 +39,8 @@ namespace Faust
class MatBSR<FPP,Cpu> : public MatGeneric<FPP,Cpu>
{
friend void gemm_gen<>(const MatGeneric<FPP, Cpu>& A, const MatGeneric<FPP, Cpu>& B, MatDense<FPP, Cpu>& out, const FPP alpha/*=FPP(1.0)*/, const FPP beta/*=(0.0)*/, const char opA/*='N'*/, const char opB/*='N'*/);
friend Transform<FPP,Cpu>; // TODO: limit to needed member functions only
BSRMat<FPP> bmat; // low-level BSRMat
MatBSR() : MatGeneric<FPP, Cpu>() {}
MatBSR(BSRMat<FPP>& mat);
......@@ -254,7 +258,16 @@ class BSRMat
* Transpose
*/
BSRMat<T, BlockStorageOrder> transpose(const bool inplace=false);
BSRMat<T, BlockStorageOrder> conjugate(const bool inplace=false);
/**
* Adjoint
*/
BSRMat<T, BlockStorageOrder> adjoint(const bool inplace=false);
/**
* \param op: 'N' (no-op), 'T' (tranpose), 'H' (adjoint), 'C' (conjugate)
*/
BSRMat<T, BlockStorageOrder> apply_op(const char op, const bool inplace=false);
/**
* \param m: matrix number of rows.
* \param n: matrix number of columns.
......
......@@ -205,9 +205,9 @@ namespace Faust
template <typename FPP>
void MatBSR<FPP,Cpu>::adjoint()
{
//TODO: it would be more efficient to transconjugate in one shot
transpose();
conjugate();
bmat.adjoint(/*inplace*/ true);
this->dim1 = bmat.m;
this->dim2 = bmat.n;
}
template <typename FPP>
......@@ -756,7 +756,46 @@ Real<T> BSRMat<T, BlockStorageOrder>::normL1() const
template<typename T, int BlockStorageOrder>
BSRMat<T, BlockStorageOrder> BSRMat<T, BlockStorageOrder>::transpose(const bool inplace/*=false*/)
{
return apply_op('T', inplace);
}
template<typename T, int BlockStorageOrder>
BSRMat<T, BlockStorageOrder> BSRMat<T, BlockStorageOrder>::conjugate(const bool inplace/*="false"*/)
{
if(inplace)
{
DenseMatMap<T> data_mat(data, bnnz*bm, bn);
data_mat = DenseMatMap<T>(data, bnnz*bm, bn).conjugate();
return *this;
}
else
{
BSRMat<T, BlockStorageOrder> cbmat(*this);
cbmat.conjugate(/*inplace=*/true);
return cbmat;
}
}
template<typename T, int BlockStorageOrder>
BSRMat<T, BlockStorageOrder> BSRMat<T, BlockStorageOrder>::adjoint(const bool inplace/*="false"*/)
{
return apply_op('H', inplace);
}
template<typename T, int BlockStorageOrder>
BSRMat<T, BlockStorageOrder> BSRMat<T, BlockStorageOrder>::apply_op(const char op, const bool inplace/*="false"*/)
{
if(op != 'C' && op != 'H' && op != 'T' && op != 'N')
throw std::runtime_error("BSRMat::apply_op: unknown op.");
if(op == 'N')
if(inplace)
return *this;
else
return BSRMat<T, BlockStorageOrder>(*this);
if(op == 'C')
return conjugate(inplace);
// op == 'T' or 'H'
if(inplace)
{
// init tbmat buffers and attributes
......@@ -777,12 +816,15 @@ BSRMat<T, BlockStorageOrder> BSRMat<T, BlockStorageOrder>::transpose(const bool
for(int j=0;j<this->b_per_coldim;j++)
{
int i = 0; // number of blocks found so far in this j-th block column
iter_block([this, &j, &i, &t_block_offset, &data, &bcolinds, &browptr](int mat_row_id, int mat_col_id, int block_offset)
iter_block([this, &op, &j, &i, &t_block_offset, &data, &bcolinds, &browptr](int mat_row_id, int mat_col_id, int block_offset)
{
if(mat_col_id/this->bn == j)
{
DenseMatMap<T> block(data+t_block_offset*this->bn*this->bm, this->bn, this->bm);
block = DenseMatMap<T>(this->data+block_offset*this->bm*this->bn, this->bm, this->bn).transpose().eval();
if(op == 'H')
block = DenseMatMap<T>(this->data+block_offset*this->bm*this->bn, this->bm, this->bn).adjoint().eval();
else if(op == 'T')
block = DenseMatMap<T>(this->data+block_offset*this->bm*this->bn, this->bm, this->bn).adjoint().eval();
bcolinds[t_block_offset++] = mat_row_id/this->bm;
i++;
}
......@@ -807,29 +849,14 @@ BSRMat<T, BlockStorageOrder> BSRMat<T, BlockStorageOrder>::transpose(const bool
else
{
BSRMat<T, BlockStorageOrder> tbmat(*this);
tbmat.transpose(/*inplace=*/true);
if(op == 'T')
tbmat.transpose(/*inplace=*/true);
else
tbmat.adjoint(/*inplace=*/true);
return tbmat;
}
}
template<typename T, int BlockStorageOrder>
BSRMat<T, BlockStorageOrder> BSRMat<T, BlockStorageOrder>::conjugate(const bool inplace/*="false"*/)
{
if(inplace)
{
DenseMatMap<T> data_mat(data, bnnz*bm, bn);
data_mat = DenseMatMap<T>(data, bnnz*bm, bn).conjugate();
return *this;
}
else
{
BSRMat<T, BlockStorageOrder> cbmat(*this);
cbmat.conjugate(/*inplace=*/true);
return cbmat;
}
}
template<typename T, int BlockStorageOrder>
bool BSRMat<T, BlockStorageOrder>::contains_nan() const
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment