Mentions légales du service

Skip to content
Snippets Groups Projects
Commit a9535e3b authored by hhakim's avatar hhakim
Browse files

Add the matfaust wrapper to palm4msa2020 with GPU2 backend.

parent b0e1d507
Branches
Tags
No related merge requests found
......@@ -5,6 +5,7 @@
%> @param M the dense matrix to factorize.
%> @param p the ParamsPalm4MSA instance to define the algorithm parameters.
%> @param 'backend',int (optional) the backend (the C++ implementation) chosen. Must be 2016 (the default) or 2020 (which should be quicker for certain configurations - e.g. factorizing a Hadamard matrix).
%> @param 'gpu', bool (optional) set to true to execute the algorithm using the GPU implementation. This options is only available when backend==2020.
%>
%> @retval F the Faust object result of the factorization.
%> @retval [F, lambda] = palm4msa(M, p) to optionally get lambda (scale).
......@@ -45,19 +46,31 @@ function [F,lambda] = palm4msa(M, p, varargin)
mex_params = p.to_mex_struct(M);
backend = 2016;
nargin = length(varargin);
gpu = false;
if(nargin > 0)
backend = varargin{1};
if(strcmp('backend', backend))
if(nargin < 2)
error('keyword argument ''backend'' must be followed by 2016 or 2020')
else
backend = varargin{2};
for i=1:nargin
switch(varargin{i})
case 'backend'
if(nargin < i+1)
error('keyword argument ''backend'' must be followed by 2016 or 2020')
else
backend = varargin{i+1};
end
case 'gpu'
if(nargin == i || ~ islogical(varargin{i+1}))
error('gpu keyword argument is not followed by a logical')
else
gpu = varargin{i+1};
end
end
end
if(~ (isscalar(backend) && floor(backend) == backend) || backend ~= 2016 && backend ~= 2020)
backend
error('backend must be a int equal to 2016 or 2020')
end
if(backend ~= 2020 && gpu == true)
error('GPU implementation is only available for 2020 backend.')
end
end
if(backend == 2016)
if(isreal(M))
......@@ -67,7 +80,11 @@ function [F,lambda] = palm4msa(M, p, varargin)
end
elseif(backend == 2020)
if(isreal(M))
[lambda, core_obj] = mexPALM4MSA2020Real(mex_params);
if(gpu)
[lambda, core_obj] = mexPALM4MSA2020_gpu2Real(mex_params);
else
[lambda, core_obj] = mexPALM4MSA2020Real(mex_params);
end
else
error('backend 2020 doesn''t handle yet the complex matrices')
end
......
......@@ -120,7 +120,9 @@ foreach(SCALAR_AND_FSUFFIX double:Real std::complex<double>:Cplx) # TODO: float
configure_file(${FAUST_MATLAB_MEX_SRC_DIR}/mexHierarchical2020.cpp.in ${FAUST_MATLAB_MEX_SRC_DIR}/mexHierarchical2020${FSUFFIX}.cpp @ONLY)
configure_file(${FAUST_MATLAB_DOC_SRC_DIR}/mexHierarchical2020.m.in ${FAUST_MATLAB_DOC_SRC_DIR}/mexHierarchical2020${FSUFFIX}.m @ONLY)
configure_file(${FAUST_MATLAB_MEX_SRC_DIR}/mexPALM4MSA2020.cpp.in ${FAUST_MATLAB_MEX_SRC_DIR}/mexPALM4MSA2020${FSUFFIX}.cpp @ONLY)
configure_file(${FAUST_MATLAB_MEX_SRC_DIR}/mexPALM4MSA2020_gpu2.cpp.in ${FAUST_MATLAB_MEX_SRC_DIR}/mexPALM4MSA2020_gpu2${FSUFFIX}.cpp @ONLY)
configure_file(${FAUST_MATLAB_DOC_SRC_DIR}/mexPALM4MSA2020.m.in ${FAUST_MATLAB_DOC_SRC_DIR}/mexPALM4MSA2020${FSUFFIX}.m @ONLY)
configure_file(${FAUST_MATLAB_DOC_SRC_DIR}/mexPALM4MSA2020_gpu2.m.in ${FAUST_MATLAB_DOC_SRC_DIR}/mexPALM4MSA2020_gpu2${FSUFFIX}.m @ONLY)
endif()
# copy the *.m for factorization now, because we have the FSUFFIX in hands
configure_file(${FAUST_MATLAB_DOC_SRC_DIR}/mexHierarchical_fact.m.in ${FAUST_MATLAB_DOC_SRC_DIR}/mexHierarchical_fact${FSUFFIX}.m COPYONLY)
......
/****************************************************************************/
/* Description: */
/* For more information on the FAuST Project, please visit the website */
/* of the project : <http://faust.inria.fr> */
/* */
/* License: */
/* Copyright (2020): Hakim HADJ-DJILANI */
/* Nicolas Bellot, Adrien Leman, Thomas Gautrais, */
/* Luc Le Magoarou, Remi Gribonval */
/* INRIA Rennes, FRANCE */
/* http://www.inria.fr/ */
/* */
/* The FAuST Toolbox is distributed under the terms of the GNU Affero */
/* General Public License. */
/* This program is free software: you can redistribute it and/or modify */
/* it under the terms of the GNU Affero General Public License as */
/* published by the Free Software Foundation. */
/* */
/* This program is distributed in the hope that it will be useful, but */
/* WITHOUT ANY WARRANTY; without even the implied warranty of */
/* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. */
/* See the GNU Affero General Public License for more details. */
/* */
/* You should have received a copy of the GNU Affero General Public */
/* License along with this program. */
/* If not, see <http://www.gnu.org/licenses/>. */
/* */
/* Contacts: */
/* Nicolas Bellot : nicolas.bellot@inria.fr */
/* Adrien Leman : adrien.leman@inria.fr */
/* Thomas Gautrais : thomas.gautrais@inria.fr */
/* Luc Le Magoarou : luc.le-magoarou@inria.fr */
/* Remi Gribonval : remi.gribonval@inria.fr */
/* */
/* References: */
/* [1] Le Magoarou L. and Gribonval R., "Flexible multi-layer sparse */
/* approximations of matrices and applications", Journal of Selected */
/* Topics in Signal Processing, 2016. */
/* <https://hal.archives-ouvertes.fr/hal-01167948v1> */
#include "mex.h"
//#include "mexutils.h"
#include "faust_gpu_mod_utils.h"
#include "faust_MatDense_gpu.h"
#include "faust_MatDense.h"
#include <vector>
#include <string>
#include <algorithm>
#include "faust_constant.h"
#include "faust_Palm4MSA.h"
#include <stdexcept>
#include "mx2Faust.h"
#include "faust2Mx.h"
#include "mx2Faust.h"
#include "faust_TransformHelper.h"
#include "class_handle.hpp"
#include "faust_palm4msa2020.h"
#include "faust_ConstraintGeneric.h"
#include "faust_ConstraintFPP.h"
using namespace Faust;
void testCoherencePalm4MSA(const mxArray* params,std::vector<bool> & presentFields);
typedef @FAUST_SCALAR@ SCALAR;
typedef @FACT_FPP@ FPP2;
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
//Faust::enable_gpu_mod();
#ifdef FAUST_VERBOSE
if (typeid(SCALAR) == typeid(float))
{
std::cout<<"SCALAR == float"<<std::endl;
}
if (typeid(SCALAR) == typeid(double))
{
std::cout<<"SCALAR == double"<<std::endl;
}
system("sleep 7");
#endif
if (nrhs != 1)
{
mexErrMsgTxt("Bad Number of inputs arguments");
}
if(!mxIsStruct(prhs[0]))
{
mexErrMsgTxt("Input must be a structure.");
}
std::vector<bool> presentFields;
testCoherencePALM4MSA(prhs[0],presentFields);
// mexPrintf(" NUMBER FIELDS %d\n",presentFields.size());
// ///////////// PALM4MSA LAUNCH ///////////////
try
{
// // creation des parametres
auto params = mxArray2FaustParamsPALM4MSA<SCALAR,FPP2>(prhs[0], presentFields);
if(params->isVerbose) params->Display();
SCALAR lambda = params->init_lambda;
cout << "mark 1" << endl;
TransformHelper<SCALAR, GPU2>* F = new TransformHelper<SCALAR, GPU2>(), *F_lambda = nullptr;
cout << "mark 2" << endl;
//TODO: use_csr, packing_RL to add to ParamsPalm and parsing of matlab_params
bool packing_RL = true;
bool use_csr = true;
//TODO: the constness should be kept
std::vector<Faust::ConstraintGeneric*> noconst_cons;
for(auto cons: params->cons)
{
noconst_cons.push_back(const_cast<Faust::ConstraintGeneric*>(cons));
if(cons->is_constraint_parameter_int<SCALAR,GPU2>())
(dynamic_cast<Faust::ConstraintInt<SCALAR,GPU2>*>(*(noconst_cons.end()-1)))->Display();
else if(cons->is_constraint_parameter_real<SCALAR,GPU2>())
(dynamic_cast<Faust::ConstraintFPP<SCALAR,GPU2>*>(*(noconst_cons.end()-1)))->Display();
}
cout << "mark 3" << endl;
palm4msa2<SCALAR,GPU2>(params->data, noconst_cons, *F, lambda, params->stop_crit, params->isUpdateWayR2L,
use_csr, packing_RL, /* compute_2norm_on_array */ false, params->norm2_threshold,
params->norm2_max_iter, params->isConstantStepSize, params->step_size);
cout << "mark 4" << endl;
Faust::MatDense<FPP2,GPU2> mat1x1Lambda = Faust::MatDense<FPP2, GPU2>(1, 1, &lambda);
Faust::MatDense<FPP2,Cpu> cpu_mat1x1Lambda;
plhs[0] = FaustMat2mxArray(cpu_mat1x1Lambda);
F_lambda = F->multiply(lambda);
delete F;
F = F_lambda;
plhs[1] = convertPtr2Mat<Faust::TransformHelper<SCALAR, GPU2>>(F);
delete params;
}
catch (const std::exception& e)
{
plhs[1] = nullptr;
mexErrMsgTxt(e.what());
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment