Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9873ab78 authored by hhakim's avatar hhakim
Browse files

Refactor mex svdtj function into Faust C++ core (in order to be reusable for pyfaust).

parent 9b0eb8f2
Branches
Tags
No related merge requests found
......@@ -37,6 +37,7 @@
#include "faust_GivensFGFTParallel.h"
#include "faust_GivensFGFTComplex.h"
#include "faust_GivensFGFTParallelComplex.h"
#include "faust_SVDTJ.h"
#include "faust_TransformHelper.h"
#include "faust_linear_algebra.h"
#include "class_handle.hpp"
......@@ -94,116 +95,28 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
void svdtj(const mxArray* matlab_matrix, int J, int t, double tol, unsigned int verbosity, bool relErr, int order, mxArray **plhs)
{
Faust::MatGeneric<SCALAR,Cpu>* M;
Faust::MatDense<SCALAR,Cpu> dM;
Faust::MatDense<SCALAR, Cpu> dM_M, dMM_; // M'*M, M*M'
Faust::MatSparse<SCALAR,Cpu> sM;
Faust::MatDense<SCALAR, Cpu> sM_M, sMM_; // M'*M, M*M'
Faust::BlasHandle<Cpu> blas_handle;
Faust::SpBlasHandle<Cpu> spblas_handle;
Faust::GivensFGFT<SCALAR, Cpu, FPP2>* algoW1;
Faust::GivensFGFT<SCALAR, Cpu, FPP2>* algoW2;
TransformHelper<SCALAR,Cpu> *U, *V;
Faust::Vect<SCALAR,Cpu>* S;
try{
if (mxIsSparse(matlab_matrix))
{
mxArray2FaustspMat(matlab_matrix,sM);
dM = sM; //TODO: optimize
spgemm(sM, dM, dM_M, 1.0, 0.0, 'T', 'N');
spgemm(sM, dM, dMM_, 1.0, 0.0, 'N', 'T');
M = &sM;
if(t <= 1)
{
algoW1 = new GivensFGFT<SCALAR,Cpu,FPP2>(dMM_, J, verbosity, tol, relErr);
algoW2 = new GivensFGFT<SCALAR,Cpu,FPP2>(dM_M, J, verbosity, tol, relErr);
}
else
{
algoW1 = new GivensFGFTParallel<SCALAR,Cpu,FPP2>(dMM_, J, t, verbosity, tol, relErr);
algoW2 = new GivensFGFTParallel<SCALAR,Cpu,FPP2>(dM_M, J, t, verbosity, tol, relErr);
}
svdtj(sM, J, t, tol, verbosity, relErr, order, &U, &V, &S);
}else
{
mxArray2FaustMat(matlab_matrix, dM);
gemm(dM, dM, dM_M, 1.0, 0.0, 'T', 'N');
gemm(dM, dM, dMM_, 1.0, 0.0, 'N', 'T');
M = &dM;
if(t <= 1)
{
algoW1 = new GivensFGFT<SCALAR,Cpu,FPP2>(dMM_, J, verbosity, tol, relErr);
algoW2 = new GivensFGFT<SCALAR,Cpu,FPP2>(dM_M, J, verbosity, tol, relErr);
}
else
{
algoW1 = new GivensFGFTParallel<SCALAR,Cpu,FPP2>(dMM_, J, t, verbosity, tol, relErr);
algoW2 = new GivensFGFTParallel<SCALAR,Cpu,FPP2>(dM_M, J, t, verbosity, tol, relErr);
}
}
//TODO: parallelize with OpenMP
algoW1->compute_facts();
algoW2->compute_facts();
Faust::Vect<SCALAR,Cpu> S(M->getNbRow());
Faust::Transform<SCALAR,Cpu> transW1 = std::move(algoW1->get_transform(order));
TransformHelper<SCALAR,Cpu> *thW1 = new TransformHelper<SCALAR,Cpu>(transW1, true); // true is for moving and not copying the Transform object into TransformHelper (optimization possible cause we know the original object won't be used later)
Faust::Transform<SCALAR,Cpu> transW2 = std::move(algoW2->get_transform(order));
TransformHelper<SCALAR,Cpu> *thW2 = new TransformHelper<SCALAR,Cpu>(transW2, true); // true is for moving and not copying the Transform object into TransformHelper (optimization possible cause we know the original object won't be used later)
// compute S = W1'*M*W2 = W1'*(W2^T*M)^T
dM.transpose();
Faust::MatDense<SCALAR,Cpu> MW2 = thW2->multiply(dM, /* transpose */ true);
MW2.transpose();
Faust::MatDense<SCALAR,Cpu> W1_MW2 = thW1->multiply(MW2, /* transpose */ true);
// create diagonal vector
for(int i=0;i<S.size();i++){
S.getData()[i] = W1_MW2(i,i);
svdtj(dM, J, t, tol, verbosity, relErr, order, &U, &V, &S);
}
// order D descendently according to the abs value
// and change the sign when the value is negative
// it gives a signed permutation matrix P to append to W1, abs(P2) is append to W2
vector<int> ord_indices;
Faust::Vect<SCALAR,Cpu> ordered_S = Faust::Vect<SCALAR,Cpu>(S.size());
vector<SCALAR> values(S.size());
vector<SCALAR> values2(S.size());
vector<int> col_ids(S.size());
ord_indices.resize(0);
order = 1;
for(int i=0;i<S.size();i++)
ord_indices.push_back(i);
sort(ord_indices.begin(), ord_indices.end(), [S, &order](int i, int j) {
return Faust::fabs(S.getData()[i]) > Faust::fabs(S.getData()[j])?1:0;
});
for(int i=0;i<ord_indices.size();i++)
{
col_ids[i] = i;
ordered_S.getData()[i] = Faust::fabs(S.getData()[ord_indices[i]]);
if(S.getData()[ord_indices[i]] < 0)
values[i] = -1;
else
values[i] = 1;
values2[i] = 1;
}
Faust::MatSparse<SCALAR, Cpu>* PS = new Faust::MatSparse<SCALAR, Cpu>(ord_indices, col_ids, values, M->getNbRow(), M->getNbCol());
thW1->push_back(PS);
Faust::MatSparse<SCALAR, Cpu>* P = new Faust::MatSparse<SCALAR, Cpu>(ord_indices, col_ids, values2, M->getNbRow(), M->getNbCol());
thW2->push_back(P);
delete algoW1;
delete algoW2;
plhs[0] = convertPtr2Mat<Faust::TransformHelper<SCALAR, Cpu>>(thW1);
plhs[1] = FaustVec2mxArray(ordered_S);
plhs[2] = convertPtr2Mat<Faust::TransformHelper<SCALAR, Cpu>>(thW2);
plhs[0] = convertPtr2Mat<Faust::TransformHelper<SCALAR, Cpu>>(U);
plhs[1] = FaustVec2mxArray(*S);
plhs[2] = convertPtr2Mat<Faust::TransformHelper<SCALAR, Cpu>>(V);
delete S; //allocated internally by Faust::svdtj
}
catch (const std::exception& e)
{
......
......@@ -431,34 +431,34 @@ void mxArray2Ptr(const mxArray* mxMat, std::complex<FPP>* & ptr_data)
template<typename FPP>
void concatMatGeneric(const mxArray * mxMat,std::vector<Faust::MatGeneric<FPP,Cpu> *> &list_mat)
{
if (mxMat == NULL)
mexErrMsgTxt("concatMatGeneric : empty matlab matrix");
mexErrMsgTxt("concatMatGeneric : empty matlab matrix");
Faust::MatGeneric<FPP,Cpu> * M;
if (!mxIsSparse(mxMat))
{
Faust::MatDense<FPP,Cpu> denseM;
mxArray2FaustMat(mxMat,denseM);
M=denseM.Clone();
}else
{
Faust::MatSparse<FPP,Cpu> spM;
mxArray2FaustspMat(mxMat,spM);
M=spM.Clone();
}
list_mat.push_back(M);
list_mat.push_back(M);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment