Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 650e45b3 authored by hhakim's avatar hhakim
Browse files

Add matfaust.fact.palm4msa_mhtp (with underlying mex code and API doc).

parent 606b2d82
Branches
Tags
No related merge requests found
......@@ -31,7 +31,7 @@ if(BUILD_DOCUMENTATION)
string(CONCAT DOXYGEN_FILE_PATTERNS "*.cpp *.hpp *.h *.cu *.hu")
endif()
if(BUILD_WRAPPER_MATLAB)
string(CONCAT DOXYGEN_FILE_PATTERNS ${DOXYGEN_FILE_PATTERNS} " Faust.m FaustMulMode.m StoppingCriterion.m ConstraintGeneric.m ConstraintMat.m ConstraintReal.m ConstraintInt.m ConstraintName.m ParamsFact.m ParamsHierarchical.m ParamsPalm4MSA.m FaustFactory.m hadamard.m quickstart.m fft.m bsl.m runtimecmp.m runall.m version.m faust_fact.m ParamsHierarchicalSquareMat.m ParamsHierarchicalRectMat.m license.m omp.m wht.m dft.m eye.m rand.m eigtj.m hierarchical.m fact.m palm4msa.m fgft_givens.m fgft_palm.m svdtj.m splin.m spcol.m proj_gen.m sp.m const.m supp.m hankel.m toeplitz.m circ.m normcol.m normlin.m splincol.m blockdiag.m skperm.m enable_gpu_mod.m isFaust.m poly.m basis.m next.m expm_multiply.m FaustPoly.m MHTPParams.m ") # warning: the space on the end matters
string(CONCAT DOXYGEN_FILE_PATTERNS ${DOXYGEN_FILE_PATTERNS} " Faust.m FaustMulMode.m StoppingCriterion.m ConstraintGeneric.m ConstraintMat.m ConstraintReal.m ConstraintInt.m ConstraintName.m ParamsFact.m ParamsHierarchical.m ParamsPalm4MSA.m FaustFactory.m hadamard.m quickstart.m fft.m bsl.m runtimecmp.m runall.m version.m faust_fact.m ParamsHierarchicalSquareMat.m ParamsHierarchicalRectMat.m license.m omp.m wht.m dft.m eye.m rand.m eigtj.m hierarchical.m fact.m palm4msa.m fgft_givens.m fgft_palm.m svdtj.m splin.m spcol.m proj_gen.m sp.m const.m supp.m hankel.m toeplitz.m circ.m normcol.m normlin.m splincol.m blockdiag.m skperm.m enable_gpu_mod.m isFaust.m poly.m basis.m next.m expm_multiply.m FaustPoly.m MHTPParams.m palm4msa_mhtp.m ") # warning: the space on the end matters
endif()
if(BUILD_WRAPPER_PYTHON)
string(CONCAT DOXYGEN_FILE_PATTERNS ${DOXYGEN_FILE_PATTERNS} "__init__.py factparams.py demo.py tools.py fact.py proj.py poly.py")
......
......@@ -3,7 +3,7 @@
%>
%>
%> @param M the dense matrix to factorize.
%> @param p the ParamsPalm4MSA instance to define the algorithm parameters.
%> @param p the matfaust.factparams.ParamsPalm4MSA instance to define the algorithm parameters.
%> @param 'backend',int (optional) the backend (the C++ implementation) chosen. Must be 2016 (the default) or 2020 (which should be quicker for certain configurations - e.g. factorizing a Hadamard matrix).
%> @param 'gpu', bool (optional) set to true to execute the algorithm using the GPU implementation. This options is only available when backend==2020.
%>
......
% experimental block start
%==========================================================================
%> @brief Runs the MHTP-PALM4MSA algorithm to factorize the matrix M.
%>
%> MHTP stands for Multilinear Hard Tresholding Pursuit. This is a generalization of the Bilinear HTP algorithm describe in [1].
%>
%> [1] Quoc-Tung Le, Rémi Gribonval. Structured Support Exploration For Multilayer Sparse Matrix Fac- torization. ICASSP 2021 - IEEE International Conference on Acoustics, Speech and Signal Processing, Jun 2021, Toronto, Ontario, Canada. pp.1-5. <a href="https://hal.inria.fr/hal-03132013/document">hal-03132013</a>
%>
%> @param M the dense matrix to factorize.
%> @param palm4msa_p the matfaust.factparams.ParamsPalm4MSA instance to define the algorithm parameters.
%> @param mthp_p the matfaust.factparams.MHTPParams instance to define the MHTP algorithm parameters.
%> @param 'gpu', bool (optional) set to true to execute the algorithm using the GPU implementation. This options is only available when backend==2020.
%>
%> @retval F the Faust object result of the factorization.
%> @retval [F, lambda] = palm4msa(M, p) to optionally get lambda (scale).
%>
%> @b Example
%> @code
%> >> % in a matlab terminal
%> >> import matfaust.fact.palm4msa_mhtp
%> >> import matfaust.factparams.*
%> >> import matfaust.proj.*
%> >> M = rand(500,32);
%> >> projs = { splin([500,32], 5), normcol([32,32], 1.0)};
%> >> stop_crit = StoppingCriterion(200);
%> >> param = ParamsPalm4MSA(projs, stop_crit);
%> >> mhtp_param = MHTPParams('num_its', 60, 'palm4msa_period', 10);
%> >> G = palm4msa_mhtp(M, param, mhtp_param)
%> @endcode
%>
%> G =
%>
%> Faust size 500x32, density 0.17825, nnz_sum 2852, 2 factor(s):
%> - FACTOR 0 (real) SPARSE, size 500x32, density 0.15625, nnz 2500
%> - FACTOR 1 (real) SPARSE, size 32x32, density 0.34375, nnz 352
%>
%==========================================================================
function [F,lambda] = palm4msa_mhtp(M, palm4msa_p, mhtp_p, varargin)
palm4msa_p.use_MHTP = mhtp_p;
[F, lambda] = matfaust.fact.palm4msa(M, palm4msa_p, 'backend', 2020, varargin{:});
end
% experimental block end
% experimental block start
% =========================================================
%> @brief This class defines the set of parameters to run the MHTP-PAL4MSA algorithm.
%>
%> See also matfaust.fact.palm4msa_mhtp, matfaust.fact.hierarchical_mhtp.
% =========================================================
classdef MHTPParams
properties (SetAccess = public, Hidden = false)
num_its;
constant_step_size;
step_size;
palm4msa_period;
updating_lambda;
end
methods
% =========================================================
%> Constructor of the MHTPParams class.
%>
%> See also matfaust.fact.palm4msa_mhtp, matfaust.fact.hierarchical_mhtp.
%>
%> @param 'num_its', int: (optional) the number of iterations to run the MHTP algorithm.
%> @param 'constant_step_size', bool: (optional) true to use a constant step for the gradient descent, False otherwise. If false the step size is computed dynamically along the iteration (according to a Lipschitz criterion).
%> @param 'step_size', real: (optional) The step size used when constant_step_size==true.
%> @param 'palm4msa_period', int: (optional) The period (in term of iterations) according to the MHTP algorithm is ran (i.e.: 0 <= i < N being the PALM4MSA iteration, MHTP is launched every i = 0 (mod palm4msa_period). Hence the algorithm is ran one time at least – at PALM4MSA iteration 0).
%> @param 'updating_lambda', bool: (optional) if true then the scale factor of the Faust resulting of the factorization is updated after each iteration of MHTP (otherwise it never changes during the whole MHTP execution).
% =========================================================
function p = MHTPParams(varargin)
argc = length(varargin);
% default parameter values
p.num_its = 50;
p.constant_step_size = false;
p.step_size = 1e-3;
p.palm4msa_period = 1000;
p.updating_lambda = true;
if(argc > 0)
for i=1:2:argc
if(argc > i)
% next arg (value corresponding to the key varargin{i})
tmparg = varargin{i+1};
end
switch(varargin{i})
case 'step_size'
if(argc == i || ~ isscalar(tmparg) || ~ isreal(tmparg) || tmparg < 0)
error('step_size argument must be followed by a positive real value')
else
p.step_size = tmparg;
end
case 'num_its'
if(argc == i || ~ isscalar(tmparg) || ~isreal(tmparg) || tmparg < 0 || tmparg-floor(tmparg) > 0)
error('num_its argument must be followed by an integer')
else
p.num_its = tmparg;
end
case 'palm4msa_period'
if(argc == i || ~ isscalar(tmparg) || ~isreal(tmparg) || tmparg < 0 || tmparg-floor(tmparg) > 0)
error('palm4msa_period argument must be followed by an integer')
else
p.palm4msa_period = tmparg;
end
case 'updating_lambda'
if(argc == i || ~ islogical(tmparg))
error('updating_lambda argument must be followed by a logical')
else
p.updating_lambda = tmparg;
end
case 'constant_step_size'
if(argc == i || ~ islogical(tmparg))
error('constant_step_size argument must be followed by a logical')
else
p.constant_step_size = tmparg;
end
otherwise
if((isstr(varargin{i}) || ischar(varargin{i})))
error([ tmparg ' unrecognized argument'])
end
end
end
end
end
end
end
% experimental block end
......@@ -16,6 +16,7 @@ classdef (Abstract) ParamsFact
packing_RL
norm2_max_iter
norm2_threshold
use_MHTP
end
properties (Constant, SetAccess = protected, Hidden)
DEFAULT_STEP_SIZE = 10^-16
......@@ -177,6 +178,7 @@ classdef (Abstract) ParamsFact
p.packing_RL = packing_RL;
p.norm2_max_iter = norm2_max_iter;
p.norm2_threshold = norm2_threshold;
p.use_MHTP = false; % by default no MHTP in PALM4MSA, neither in hierarchical fact.
end
function bool = is_mat_consistent(this, M)
......
......@@ -21,7 +21,7 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
%> @param 'init_lambda', real (optional) the scale scalar initial value (by default the value is one).
%> @param 'step_size', real (optional) the initial step of the PALM descent.
%> @param 'constant_step_size', real if true the step_size keeps constant along the algorithm iterations otherwise it is updated before every factor update.
%> @param 'is_verbose', boo (optional) True to enable the verbose mode.
%> @param 'is_verbose', bool (optional) True to enable the verbose mode.
%> parameter is experimental, its value shouldn't be changed.
%> @param 'norm2_max_iter', real (optional) maximum number of iterations of power iteration algorithm. Used for computing 2-norm.
%> @param 'norm2_threshold', real (optional) power iteration algorithm threshold (default to 1e-6). Used for computing 2-norm.
......@@ -99,6 +99,18 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
end
% put mex_constraints in a cell array again because mex eats one level of array
mex_params = struct('data', M, 'nfacts', this.num_facts, 'cons', {mex_constraints}, 'init_facts', {this.init_facts}, 'niter', this.stop_crit.num_its, 'sc_is_criterion_error', this.stop_crit.is_criterion_error, 'sc_error_treshold', this.stop_crit.tol, 'sc_max_num_its', this.stop_crit.maxiter, 'update_way', this.is_update_way_R2L, 'grad_calc_opt_mode', this.grad_calc_opt_mode, 'constant_step_size', this.constant_step_size, 'step_size', this.step_size, 'verbose', this.is_verbose, 'norm2_max_iter', this.norm2_max_iter, 'norm2_threshold', this.norm2_threshold, 'init_lambda', this.init_lambda, 'use_csr', this.use_csr, 'packing_RL', this.packing_RL);
if(~ (islogical(this.use_MHTP) && this.use_MHTP == false))
% use_MHTP must be a MHTPParams if not false (cf. ParamsFact)
if(~ isa(this.use_MHTP, 'matfaust.factparams.MHTPParams'))
error('use_MHTP is not a MHTPParams')
end
mhtp_p = this.use_MHTP;
mex_params.mhtp_num_its = mhtp_p.num_its;
mex_params.mhtp_constant_step_size = mhtp_p.constant_step_size;
mex_params.mhtp_step_size = mhtp_p.step_size;
mex_params.mhtp_palm4msa_period = mhtp_p.palm4msa_period;
mex_params.mhtp_updating_lambda = mhtp_p.updating_lambda;
end
end
end
methods
......
......@@ -78,6 +78,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
system("sleep 7");
#endif
std::cout << "mexPALM4MSA2020" << std::endl;
if (nrhs != 1)
{
......@@ -98,6 +99,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// // creation des parametres
auto params = mxArray2FaustParamsPALM4MSA<SCALAR,FPP2>(prhs[0], presentFields);
Faust::MHTPParams<SCALAR> mhtp_params;
mxArray2FaustMHTPParams<SCALAR>(prhs[0], mhtp_params);
std::cout << mhtp_params.to_string() << std::endl;
if(params->isVerbose) params->Display();
// Faust::BlasHandle<Cpu> blas_handle;
SCALAR lambda = params->init_lambda;
......@@ -118,8 +122,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
//
}
palm4msa2(params->data, noconst_cons, *F, lambda, params->stop_crit, params->isUpdateWayR2L,
params->use_csr, params->packing_RL, /* compute_2norm_on_array */ false, params->norm2_threshold,
params->norm2_max_iter, params->isConstantStepSize, params->step_size);
params->use_csr, params->packing_RL, mhtp_params, /* compute_2norm_on_array */ false, params->norm2_threshold,
params->norm2_max_iter, params->isConstantStepSize, params->step_size, false /* on_gpu*/, params->isVerbose);
Faust::MatDense<FPP2,Cpu> mat1x1Lambda = Faust::MatDense<FPP2, Cpu>(&lambda, 1, 1);
plhs[0] = FaustMat2mxArray(mat1x1Lambda);
F_lambda = F->multiply(lambda);
......
......@@ -48,6 +48,7 @@
#include "faust_constant.h"
#include "faust_Params.h"
#include "faust_ParamsPalm.h"
#include "faust_MHTP.h"
#include <complex>
#include <string>
......@@ -203,6 +204,9 @@ void testCoherencePALM4MSA(const mxArray* params,std::vector<bool> & presentFiel
template<typename SCALAR, typename FPP2>
const Faust::ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* matlab_params, std::vector<bool>& presentFields);
template<typename SCALAR>
void mxArray2FaustMHTPParams(const mxArray* matlab_params, Faust::MHTPParams<SCALAR>& params);
#include "mx2Faust.hpp"
......
......@@ -964,6 +964,39 @@ const ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* ma
return params;
}
template<typename SCALAR>
void mxArray2FaustMHTPParams(const mxArray* matlab_params, Faust::MHTPParams<SCALAR>& params)
{
// all fields are optional
mxArray *mx_field = mxGetField(matlab_params, 0, "mhtp_num_its");
std::cout << "mxArray2FaustMHTPParams" << std::endl;
if(params.used = (mx_field != nullptr))
params.sc = Faust::StoppingCriterion<SCALAR>((int) mxGetScalar(mx_field));
std::cout << "mxArray2FaustMHTPParams 1" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_constant_step_size");
if(params.used = (params.used || (mx_field != nullptr)))
params.constant_step_size = (bool) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 2" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_step_size");
if(params.used = (params.used || (mx_field != nullptr)))
params.step_size = (Real<SCALAR>) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 3" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_palm4msa_period");
if(params.used = (params.used || (mx_field != nullptr)))
params.palm4msa_period = (int) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 4" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_updating_lambda");
if(params.used = (params.used || (mx_field != nullptr)))
params.updating_lambda = (bool) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 5" << std::endl;
std::cout << "params.used" << params.used << std::endl;
}
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment