Mentions légales du service

Skip to content
Snippets Groups Projects
Commit aaa148fd authored by hhakim's avatar hhakim
Browse files

Update matfaust.Faust.factors to handle BSR matrix retrieval (converting it to...

Update matfaust.Faust.factors to handle BSR matrix retrieval (converting it to a sparse matrix because matlab doesn't support BSR matrices).

- Add Transform(HelperGen)::is_fact_bsr
- Rewrite FaustSpMat2MxArray mex function to handle float and complex cases (in addition to double type). When a MatSparse float is converted to mxArray it is also converted to double precision because matlab does not support float precision for sparse matrices.
- Other minor changes.
parent 9b68582a
No related branches found
No related tags found
No related merge requests found
......@@ -123,6 +123,7 @@ namespace Faust
friend class MatDense<FPP,Cpu>;
friend class MatSparse<std::complex<double>, Cpu>;
friend class MatSparse<double, Cpu>;
friend class MatSparse<float, Cpu>;
//friend void MatDense<FPP,Cpu>::operator+=(const MatSparse<FPP,Cpu>& S);
public:
......
......@@ -142,6 +142,7 @@ namespace Faust
MatGeneric<FPP,Cpu>* get_fact(faust_unsigned_int id, const bool cloning_fact = true) const;
bool is_fact_sparse(const faust_unsigned_int id) const;
bool is_fact_dense(const faust_unsigned_int id) const;
bool is_fact_bsr(const faust_unsigned_int id) const;
faust_unsigned_int get_fact_nnz(const faust_unsigned_int id) const;
void get_fact(const faust_unsigned_int id,
const int** row_ids,
......
......@@ -624,8 +624,10 @@ double Faust::Transform<FPP,Cpu>::spectralNorm(const int nbr_iter_max, double th
return 1; // TODO: why?
}else
{
if(this->is_zero)
if(this->is_zero) // The Faust is zero by at least one of its factors
return 0;
// The Faust can still by zero (without any of its factor being)
// this case will be detected in power_iteration
//std::cout<<"Faust debut spectralNorm"<<std::endl;
//std::cout<<"copy constructor"<<std::endl;
Faust::Transform<FPP,Cpu> AtA((*this)); // modif AL <FPP,Cpu>
......@@ -1169,6 +1171,12 @@ bool Faust::Transform<FPP,Cpu>::is_fact_dense(const faust_unsigned_int id) const
return get_fact(id, false)->getType() == MatType::Dense;
}
template<typename FPP>
bool Faust::Transform<FPP,Cpu>::is_fact_bsr(const faust_unsigned_int id) const
{
return get_fact(id, false)->getType() == MatType::BSR;
}
template<typename FPP>
faust_unsigned_int Faust::Transform<FPP,Cpu>::get_fact_nnz(const faust_unsigned_int id) const
{
......
......@@ -425,6 +425,12 @@ namespace Faust
return get_fact(id, /*cloning*/ false)->getType() == Dense;
}
template<>
bool Faust::Transform<@FAUST_SCALAR_FOR_GM@, GPU2>::is_fact_bsr(int id) const
{
return get_fact(id, /*cloning*/ false)->getType() == BSR;
}
template<>
void Transform<@FAUST_SCALAR_FOR_GM@,GPU2>::get_fact(const faust_unsigned_int &id,
@FAUST_SCALAR_FOR_GM@* elts,
......
......@@ -51,6 +51,7 @@ namespace Faust
void get_facts(std::vector<MatGeneric<FPP,GPU2>*> &factors, bool cloning_facts=true) const;
bool is_fact_sparse(int id) const;
bool is_fact_dense(int id) const;
bool is_fact_bsr(int id) const;
void transpose();
int32_t getNbRow()const;
int32_t getNbCol()const;
......
......@@ -63,6 +63,7 @@ namespace Faust
virtual faust_unsigned_int get_fact_nnz(const faust_unsigned_int id) const;
virtual bool is_fact_sparse(const faust_unsigned_int id) const;
virtual bool is_fact_dense(const faust_unsigned_int id) const;
virtual bool is_fact_bsr(const faust_unsigned_int id) const;
virtual MatType get_fact_type(const faust_unsigned_int id) const;
virtual void pack_factors(faust_unsigned_int start_id, faust_unsigned_int end_id, const int mul_order_opt_mode=DEFAULT_L2R)=0;
......
......@@ -130,6 +130,12 @@ namespace Faust
return this->transform->is_fact_sparse(this->is_transposed?size()-id-1:id);
}
template<typename FPP, FDevice DEV>
bool TransformHelperGen<FPP,DEV>::is_fact_bsr(const faust_unsigned_int id) const
{
return this->transform->is_fact_bsr(this->is_transposed?size()-id-1:id);
}
template<typename FPP, FDevice DEV>
MatType TransformHelperGen<FPP,DEV>::get_fact_type(const faust_unsigned_int id) const
{
......
......@@ -1100,6 +1100,7 @@ classdef Faust
%> @retval factors a matrix copy of the i-th factor if i is a single index or a new Faust composed of i-th to the j-th factors of F. The factors copies keep the storage organization of the source matrix (full or sparse).
%>
%> @note Matlab doesn't support float sparse matrices, but matfaust does! Hence when you call Faust.factors on a float sparse Faust to retrieve one factor you'll get a double sparse matrix as a copy of the float sparse matrix.
%> @note As well for BSR matrices that aren't not supported by Matlab, the function can't return a bsr matrix so it rather converts it on the fly to a sparse matrix that is finally returned.
%>
%> @b Example
%> @code
......
......@@ -106,7 +106,7 @@ template<>
void newMxGetData<float>(float*& ptr_out, const mxArray* mxMat)
{
if(mxGetClassID(mxMat) != mxSINGLE_CLASS || mxIsComplex(mxMat))
mexErrMsgTxt("newMxGetData: the mex matrix must be double as the ptr is.");
mexErrMsgTxt("newMxGetData: the mex matrix must be float as the ptr is.");
ptr_out = static_cast<float*> (mxGetSingles(mxMat));
}
......@@ -329,13 +329,45 @@ mxArray* FaustSpMat2mxArray(const Faust::MatSparse<FPP,Cpu>& M)
mxREAL);
mwIndex* ir = mxGetIr(sparseMat);
mwIndex* jc = mxGetJc(sparseMat);
FPP* pr;
const Faust::MatSparse<double, Cpu>* dM = nullptr;
Faust::MatSparse<double, Cpu> dmat;
// matlab sparse matrix cannot be in single/float precision, convert it to double before processing
if(std::is_same<FPP, float>::value) return FaustSpMat2mxArray(M.template to_real<double>());
dM = (const Faust::MatSparse<double, Cpu>*)(&M);
double* pr;
//TODO: and complex case ?
#ifdef MX_HAS_INTERLEAVED_COMPLEX
pr = static_cast<FPP*>(mxGetDoubles(sparseMat));
pr = static_cast<double*>(mxGetDoubles(sparseMat));
#else
pr = static_cast<FPP*>(mxGetPr(sparseMat));
#endif
// sadly we can't do a direct copy into ir and jc because MatSparse uses int type for indices
// (not mwIndex which is in fact a size_t)
// so we need to copy in intermediate buffers and then affect their elements
// into jc, ir
auto tM = *dM;
// transpose M because matlab is in CSC format while Faust is in CSR
tM.transpose();
// tM.copyRowPtr(jc);
// tM.copyColInd(ir);
// tM.copyValuePtr(pr);
tM.copyBufs(jc, ir, pr);
return sparseMat;
}
#ifdef MX_HAS_INTERLEAVED_COMPLEX
template<>
mxArray* FaustSpMat2mxArray(const Faust::MatSparse<std::complex<double>,Cpu>& M)
{
faust_unsigned_int nnz = M.getNonZeros();
mxArray* sparseMat = mxCreateSparse(M.getNbRow(),
M.getNbCol(),
nnz,
mxCOMPLEX);
mwIndex* ir = mxGetIr(sparseMat);
mwIndex* jc = mxGetJc(sparseMat);
std::complex<double>* pr;
pr = reinterpret_cast<std::complex<double>*>(mxGetComplexDoubles(sparseMat));
// sadly we can't do a direct copy into ir and jc because MatSparse uses int type for indices
// (not mwIndex which is in fact a size_t)
// so we need to copy in intermediate buffers and then affect their elements
......@@ -343,12 +375,13 @@ mxArray* FaustSpMat2mxArray(const Faust::MatSparse<FPP,Cpu>& M)
auto tM = M;
// transpose M because matlab is in CSC format while Faust is in CSR
tM.transpose();
// tM.copyRowPtr(jc);
// tM.copyColInd(ir);
// tM.copyValuePtr(pr);
// tM.copyRowPtr(jc);
// tM.copyColInd(ir);
// tM.copyValuePtr(pr);
tM.copyBufs(jc, ir, pr);
return sparseMat;
}
#endif
template<FDevice DEV>
mxArray* transformFact2SparseMxArray(faust_unsigned_int id, Faust::TransformHelper<float, DEV>* core_ptr)
......@@ -361,9 +394,9 @@ mxArray* transformFact2SparseMxArray(faust_unsigned_int id, Faust::TransformHelp
mwIndex* jc = mxGetJc(sparseMat);
// std::cout << "transformFact2SparseMxArray()" << std::endl;
// FPP* pr = static_cast<FPP*>(mxGetDoubles(sparseMat)); //given up because fails to compile with template FPP
double* pr;
float* pr;
#ifdef MX_HAS_INTERLEAVED_COMPLEX
pr = static_cast<double*>(mxGetDoubles(sparseMat));
newMxGetData(pr, sparseMat);
#else
pr = static_cast<double*>(mxGetPr(sparseMat));
#endif
......
......@@ -163,7 +163,7 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
else
nb_element_tmp = mxGetNumberOfElements(mxMat);
const size_t NB_ELEMENTS = nb_element_tmp;
const size_t nb_elts = nb_element_tmp;
if(V_CLASS_ID == mxDOUBLE_CLASS)
......@@ -173,9 +173,9 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
{
if(! is_same<FPP, complex<double>>::value)
mexErrMsgTxt("mxMat is complex double, the output buffer must be complex<double>");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxComplexDouble* mx_cplx_doubles = mxGetComplexDoubles(mxMat);
memcpy(ptr_data, mx_cplx_doubles, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_cplx_doubles, sizeof(FPP)*nb_elts);
}
else
{
......@@ -183,9 +183,9 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
{
mexErrMsgTxt("mxMat is double, the output buffer must be double");
}
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxDouble* mx_doubles = mxGetDoubles(mxMat);
memcpy(ptr_data, mx_doubles, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_doubles, sizeof(FPP)*nb_elts);
}
}
......@@ -198,17 +198,17 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
{
if(! is_same<FPP, complex<float>>::value)
mexErrMsgTxt("mxMat is complex float, the output buffer must be complex<float>");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxComplexSingle* mx_cplx_floats = mxGetComplexSingles(mxMat);
memcpy(ptr_data, mx_cplx_floats, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_cplx_floats, sizeof(FPP)*nb_elts);
}
else
{
if(! is_same<FPP, float>::value)
mexErrMsgTxt("mxMat is float, the output buffer must be float");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxSingle* mx_floats = mxGetSingles(mxMat);
memcpy(ptr_data, mx_floats, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_floats, sizeof(FPP)*nb_elts);
}
}
......@@ -224,9 +224,9 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
if(! is_same<FPP, char>::value)
mexErrMsgTxt("mxMat is int8s, the output buffer must be char");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxInt8* mx_chars = mxGetInt8s(mxMat);
memcpy(ptr_data, mx_chars, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_chars, sizeof(FPP)*nb_elts);
}
......@@ -235,9 +235,9 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
{
if(! is_same<FPP, unsigned char>::value)
mexErrMsgTxt("mxMat is uint8, the output buffer must be unsigned char");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxUint8* mx_chars = mxGetUint8s(mxMat);
memcpy(ptr_data, mx_chars, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_chars, sizeof(FPP)*nb_elts);
}
else if(V_CLASS_ID == mxINT16_CLASS)
......@@ -245,9 +245,9 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
{
if(! is_same<FPP, short>::value)
mexErrMsgTxt("mxMat is int16s, the output buffer must be short");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxInt16* mx_chars = mxGetInt16s(mxMat);
memcpy(ptr_data, mx_chars, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_chars, sizeof(FPP)*nb_elts);
}
else if (V_CLASS_ID == mxUINT16_CLASS)
......@@ -255,41 +255,41 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data)
{
if(! is_same<FPP, unsigned short>::value)
mexErrMsgTxt("mxMat is uint16, the output buffer must be unsigned short");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxUint16* mx_ints = mxGetUint16s(mxMat);
memcpy(ptr_data, mx_ints, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_ints, sizeof(FPP)*nb_elts);
}
else if (V_CLASS_ID == mxINT32_CLASS)
{
if(! is_same<FPP, int>::value)
mexErrMsgTxt("mxMat is int32s, the output buffer must be int");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxInt32* mx_ints = mxGetInt32s(mxMat);
memcpy(ptr_data, mx_ints, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_ints, sizeof(FPP)*nb_elts);
}
else if (V_CLASS_ID == mxUINT32_CLASS)
{
if(! is_same<FPP, unsigned int>::value)
mexErrMsgTxt("mxMat is uint32, the output buffer must be unsigned int");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxUint32* mx_ints = mxGetUint32s(mxMat);
memcpy(ptr_data, mx_ints, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_ints, sizeof(FPP)*nb_elts);
}
else if (V_CLASS_ID == mxINT64_CLASS)
{
if(! is_same<FPP, long int>::value)
mexErrMsgTxt("mxMat is int64s, the output buffer must be long int");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxInt64* mx_ints = mxGetInt64s(mxMat);
memcpy(ptr_data, mx_ints, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_ints, sizeof(FPP)*nb_elts);
}
else if (V_CLASS_ID == mxUINT64_CLASS)
{
if(! is_same<FPP, unsigned long int>::value)
mexErrMsgTxt("mxMat is uint64, the output buffer must be unsigned long int");
ptr_data = new FPP[NB_ELEMENTS];
ptr_data = new FPP[nb_elts];
mxUint64* mx_ints = mxGetUint64s(mxMat);
memcpy(ptr_data, mx_ints, sizeof(FPP)*NB_ELEMENTS);
memcpy(ptr_data, mx_ints, sizeof(FPP)*nb_elts);
}
else
mexErrMsgTxt("Unknown matlab type.");
......@@ -316,15 +316,15 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data,FUNCTOR & mxGetDataFu
else
nb_element_tmp = mxGetNumberOfElements(mxMat);
const size_t NB_ELEMENTS = nb_element_tmp;
const size_t nb_elts = nb_element_tmp;
if(V_CLASS_ID == mxDOUBLE_CLASS)
{
double* ptr_data_tmp = static_cast<double*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
......@@ -335,8 +335,8 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data,FUNCTOR & mxGetDataFu
float* ptr_data_tmp = static_cast<float*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
......@@ -348,8 +348,8 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data,FUNCTOR & mxGetDataFu
char* ptr_data_tmp = static_cast<char*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
......@@ -358,8 +358,8 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data,FUNCTOR & mxGetDataFu
{
unsigned char* ptr_data_tmp = static_cast<unsigned char*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
......@@ -370,8 +370,8 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data,FUNCTOR & mxGetDataFu
short* ptr_data_tmp = static_cast<short*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
......@@ -379,40 +379,40 @@ void mxArray2PtrBase(const mxArray* mxMat, FPP* & ptr_data,FUNCTOR & mxGetDataFu
{
unsigned short* ptr_data_tmp = static_cast<unsigned short*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxINT32_CLASS)
{
int* ptr_data_tmp = static_cast<int*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT32_CLASS)
{
unsigned int* ptr_data_tmp = static_cast<unsigned int*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxINT64_CLASS)
{
long long* ptr_data_tmp = static_cast<long long*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT64_CLASS)
{
unsigned long long* ptr_data_tmp = static_cast<unsigned long long*> (mxGetDataFunc(mxMat));
ptr_data = new FPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data = new FPP[nb_elts];
for (size_t i =0 ; i<nb_elts ; i++)
ptr_data[i] = static_cast<FPP> (ptr_data_tmp[i]);
}
else
......@@ -456,13 +456,13 @@ void mxArray2Ptr(const mxArray* mxMat, std::complex<FPP>* & ptr_data)
else
nb_element_tmp = mxGetNumberOfElements(mxMat);
const size_t NB_ELEMENTS = nb_element_tmp;
const size_t nb_elts = nb_element_tmp;
// get the real part of the Matlab Matrix
FPP* ptr_real_part_data;
mxArray2Ptr(mxMat,ptr_real_part_data);
ptr_data = new std::complex<FPP>[NB_ELEMENTS];
ptr_data = new std::complex<FPP>[nb_elts];
if (mxIsComplex(mxMat))
{
......@@ -473,7 +473,7 @@ void mxArray2Ptr(const mxArray* mxMat, std::complex<FPP>* & ptr_data)
mxArray2PtrBase(mxMat,ptr_imag_part_data,mxGetImagData);
// copy the values in the output vector
for (int i=0;i < NB_ELEMENTS;i++)
for (int i=0;i < nb_elts;i++)
ptr_data[i]=std::complex<FPP>(ptr_real_part_data[i],ptr_imag_part_data[i]);
......@@ -484,7 +484,7 @@ void mxArray2Ptr(const mxArray* mxMat, std::complex<FPP>* & ptr_data)
{
// Real Matlab Matrix
// copy only the real part of the matrix (the imaginary part is set to zero
for (int i=0;i < NB_ELEMENTS;i++)
for (int i=0;i < nb_elts;i++)
ptr_data[i]=std::complex<FPP>(ptr_real_part_data[i],(FPP) 0.0);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment