Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9e4f3e16 authored by hhakim's avatar hhakim
Browse files

Make possible to use ParamsPalm4MSA.init_facts for the 2020 implementation of matfaust.

parent 0c1d3ecd
Branches
No related tags found
No related merge requests found
......@@ -79,11 +79,14 @@ function [F,lambda] = palm4msa(M, p, varargin)
[lambda, core_obj] = mexPalm4MSACplx(mex_params);
end
elseif(backend == 2020)
init_faust = matfaust.Faust(p.init_facts);
% no need to keep the ParamsPalm4MSA extracted/generated cell for init_facts
% mex_params = rmfield(mex_params, 'init_facts')
if(isreal(M))
if(gpu)
[lambda, core_obj] = mexPALM4MSA2020_gpu2Real(mex_params);
[lambda, core_obj] = mexPALM4MSA2020_gpu2Real(mex_params, get_handle(init_faust));
else
[lambda, core_obj] = mexPALM4MSA2020Real(mex_params);
[lambda, core_obj] = mexPALM4MSA2020Real(mex_params, get_handle(init_faust));
end
else
error('backend 2020 doesn''t handle yet the complex matrices')
......
......@@ -1991,6 +1991,10 @@ classdef Faust
function set_Fv_mul_mode(self, mode)
set_Fv_mul_mode(self.matrix, mode)
end
function H = get_handle(self)
H = self.matrix.objectHandle;
end
end
methods(Access = private)
%================================================================
......
......@@ -78,11 +78,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
system("sleep 7");
#endif
std::cout << "mexPALM4MSA2020" << std::endl;
if (nrhs != 1)
if (nrhs < 1)
{
mexErrMsgTxt("Bad Number of inputs arguments");
mexErrMsgTxt("Bad Number of inputs arguments (must be 1 or 2)");
}
if(!mxIsStruct(prhs[0]))
......@@ -97,7 +95,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
try
{
// // creation des parametres
// PALM4MSA parameters
auto params = mxArray2FaustParamsPALM4MSA<SCALAR,FPP2>(prhs[0], presentFields);
Faust::MHTPParams<SCALAR> mhtp_params;
mxArray2FaustMHTPParams<SCALAR>(prhs[0], mhtp_params);
......@@ -105,7 +103,23 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// Faust::BlasHandle<Cpu> blas_handle;
SCALAR lambda = params->init_lambda;
TransformHelper<SCALAR, Cpu>* F = new TransformHelper<SCALAR, Cpu>(), *F_lambda = nullptr;
TransformHelper<SCALAR, Cpu>* F = nullptr, *F_lambda = nullptr;
if(nrhs < 2)
F = new TransformHelper<SCALAR, Cpu>();
else
{
// init_facts passed as a Faust in 2nd argument
F = convertMat2Ptr<Faust::TransformHelper<SCALAR,Cpu> >(prhs[1]);
//TODO: understand why this workaround is necessary
// when this copy is not made manually palm4msa2 crashes
std::vector<Faust::MatGeneric<SCALAR,Cpu>*> facts;
for(int i=0;i<F->size();i++)
{
facts.push_back(F->get_gen_fact_nonconst(i));
}
F = new TransformHelper<SCALAR, Cpu>(facts, (SCALAR) 1.0, false, true, true);
}
//TODO: the constness should be kept
std::vector<Faust::ConstraintGeneric*> noconst_cons;
......@@ -126,7 +140,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Faust::MatDense<FPP2,Cpu> mat1x1Lambda = Faust::MatDense<FPP2, Cpu>(&lambda, 1, 1);
plhs[0] = FaustMat2mxArray(mat1x1Lambda);
F_lambda = F->multiply(lambda);
delete F;
delete F; // if init_facts is used F is a copy too
F = F_lambda;
plhs[1] = convertPtr2Mat<Faust::TransformHelper<SCALAR, Cpu>>(F);
delete params;
......
......@@ -68,6 +68,7 @@ typedef @FACT_FPP@ FPP2;
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
//TODO: factorize with mexPALM4MSA2020.cpp.in
Faust::enable_gpu_mod();
#ifdef FAUST_VERBOSE
if (typeid(SCALAR) == typeid(float))
......@@ -83,9 +84,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
#endif
if (nrhs != 1)
if (nrhs < 1)
{
mexErrMsgTxt("Bad Number of inputs arguments");
mexErrMsgTxt("Bad Number of inputs arguments (must be 1 or 2)");
}
if(!mxIsStruct(prhs[0]))
......@@ -107,7 +108,24 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if(params->isVerbose) params->Display();
SCALAR lambda = params->init_lambda;
TransformHelper<SCALAR, GPU2>* F = new TransformHelper<SCALAR, GPU2>(), *F_lambda = nullptr;
TransformHelper<SCALAR, GPU2>* F = nullptr, *F_lambda = nullptr;
if(nrhs < 2)
F = new TransformHelper<SCALAR, GPU2>();
else
{
// init_facts passed as a Faust in 2nd argument
F = convertMat2Ptr<Faust::TransformHelper<SCALAR,GPU2> >(prhs[1]);
//TODO: understand why this workaround is necessary
// when this copy is not made manually palm4msa2 crashes
std::vector<Faust::MatGeneric<SCALAR,GPU2>*> facts;
for(int i=0;i<F->size();i++)
{
facts.push_back(F->get_gen_fact_nonconst(i));
}
F = new TransformHelper<SCALAR, GPU2>(facts, (SCALAR) 1.0, false, true, true);
}
//TODO: the constness should be kept
std::vector<Faust::ConstraintGeneric*> noconst_cons;
for(auto cons: params->cons)
......@@ -145,7 +163,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Faust::MatDense<FPP2,Cpu> mat1x1Lambda = Faust::MatDense<FPP2, Cpu>(1, 1, &lambda);
plhs[0] = FaustMat2mxArray(mat1x1Lambda);
F_lambda = F->multiply(lambda);
delete F;
delete F; // if init_facts is used F is a copy too
F = F_lambda;
auto cpuF = F->tocpu();
plhs[1] = convertPtr2Mat<Faust::TransformHelper<SCALAR, Cpu>>(cpuF);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment