Mentions légales du service

Skip to content
Snippets Groups Projects
Commit f772b58b authored by hhakim's avatar hhakim
Browse files

Add matfaust.fact.hierarchical_mhtp (mex, matlab function and API doc).

Other related minor changes.
parent dd50067a
No related branches found
No related tags found
No related merge requests found
Showing with 67 additions and 18 deletions
...@@ -31,7 +31,7 @@ if(BUILD_DOCUMENTATION) ...@@ -31,7 +31,7 @@ if(BUILD_DOCUMENTATION)
string(CONCAT DOXYGEN_FILE_PATTERNS "*.cpp *.hpp *.h *.cu *.hu") string(CONCAT DOXYGEN_FILE_PATTERNS "*.cpp *.hpp *.h *.cu *.hu")
endif() endif()
if(BUILD_WRAPPER_MATLAB) if(BUILD_WRAPPER_MATLAB)
string(CONCAT DOXYGEN_FILE_PATTERNS ${DOXYGEN_FILE_PATTERNS} " Faust.m FaustMulMode.m StoppingCriterion.m ConstraintGeneric.m ConstraintMat.m ConstraintReal.m ConstraintInt.m ConstraintName.m ParamsFact.m ParamsHierarchical.m ParamsPalm4MSA.m FaustFactory.m hadamard.m quickstart.m fft.m bsl.m runtimecmp.m runall.m version.m faust_fact.m ParamsHierarchicalSquareMat.m ParamsHierarchicalRectMat.m license.m omp.m wht.m dft.m eye.m rand.m eigtj.m hierarchical.m fact.m palm4msa.m fgft_givens.m fgft_palm.m svdtj.m splin.m spcol.m proj_gen.m sp.m const.m supp.m hankel.m toeplitz.m circ.m normcol.m normlin.m splincol.m blockdiag.m skperm.m enable_gpu_mod.m isFaust.m poly.m basis.m next.m expm_multiply.m FaustPoly.m MHTPParams.m palm4msa_mhtp.m ") # warning: the space on the end matters string(CONCAT DOXYGEN_FILE_PATTERNS ${DOXYGEN_FILE_PATTERNS} " Faust.m FaustMulMode.m StoppingCriterion.m ConstraintGeneric.m ConstraintMat.m ConstraintReal.m ConstraintInt.m ConstraintName.m ParamsFact.m ParamsHierarchical.m ParamsPalm4MSA.m FaustFactory.m hadamard.m quickstart.m fft.m bsl.m runtimecmp.m runall.m version.m faust_fact.m ParamsHierarchicalSquareMat.m ParamsHierarchicalRectMat.m license.m omp.m wht.m dft.m eye.m rand.m eigtj.m hierarchical.m fact.m palm4msa.m fgft_givens.m fgft_palm.m svdtj.m splin.m spcol.m proj_gen.m sp.m const.m supp.m hankel.m toeplitz.m circ.m normcol.m normlin.m splincol.m blockdiag.m skperm.m enable_gpu_mod.m isFaust.m poly.m basis.m next.m expm_multiply.m FaustPoly.m MHTPParams.m palm4msa_mhtp.m hierarchical_mhtp.m ") # warning: the space on the end matters
endif() endif()
if(BUILD_WRAPPER_PYTHON) if(BUILD_WRAPPER_PYTHON)
string(CONCAT DOXYGEN_FILE_PATTERNS ${DOXYGEN_FILE_PATTERNS} "__init__.py factparams.py demo.py tools.py fact.py proj.py poly.py") string(CONCAT DOXYGEN_FILE_PATTERNS ${DOXYGEN_FILE_PATTERNS} "__init__.py factparams.py demo.py tools.py fact.py proj.py poly.py")
......
...@@ -83,7 +83,8 @@ namespace Faust ...@@ -83,7 +83,8 @@ namespace Faust
const FPP2 init_lambda_ = defaultLambda, const FPP2 init_lambda_ = defaultLambda,
const bool constant_step_size_ = defaultConstantStepSize, const bool constant_step_size_ = defaultConstantStepSize,
const FPP2 step_size_ = defaultStepSize, const FPP2 step_size_ = defaultStepSize,
const GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode); const GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode,
const bool use_MHTP = Params<FPP,DEVICE, FPP2>::defaultUseMHTP);
void check_constraint_validity(); void check_constraint_validity();
ParamsPalm(); ParamsPalm();
......
...@@ -93,7 +93,7 @@ Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm( ...@@ -93,7 +93,7 @@ Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm(
const FPP2 init_lambda_ /* = 1.0 */, const FPP2 init_lambda_ /* = 1.0 */,
const bool constant_step_size_, const bool constant_step_size_,
const FPP2 step_size_, const FPP2 step_size_,
const GradientCalcOptMode gradCalcOptMode /* default INTERNAL_OPT*/) : const GradientCalcOptMode gradCalcOptMode /* default INTERNAL_OPT*/, const bool use_MHTP/*= Params<FPP,DEVICE, FPP2>::defaultUseMHTP*/) :
data(data_), data(data_),
nbFact(nbFact_), nbFact(nbFact_),
cons(cons_), cons(cons_),
...@@ -106,7 +106,8 @@ Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm( ...@@ -106,7 +106,8 @@ Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm(
init_lambda(init_lambda_), init_lambda(init_lambda_),
gradCalcOptMode(gradCalcOptMode), gradCalcOptMode(gradCalcOptMode),
norm2_threshold(FAUST_PRECISION), norm2_threshold(FAUST_PRECISION),
norm2_max_iter(FAUST_NORM2_MAX_ITER) norm2_max_iter(FAUST_NORM2_MAX_ITER),
use_MHTP(use_MHTP)
{ {
check_constraint_validity(); check_constraint_validity();
} }
......
...@@ -136,7 +136,7 @@ function varargout = hierarchical(M, p, varargin) ...@@ -136,7 +136,7 @@ function varargout = hierarchical(M, p, varargin)
if(~ p.is_mat_consistent(M)) if(~ p.is_mat_consistent(M))
error('M''s number of columns must be consistent with the last residuum constraint defined in p. Likewise its number of rows must be consistent with the first factor constraint defined in p.') error('M''s number of columns must be consistent with the last residuum constraint defined in p. Likewise its number of rows must be consistent with the first factor constraint defined in p.')
end end
mex_params = p.to_mex_struct(); mex_params = p.to_mex_struct()
backend = 2016; backend = 2016;
nargin = length(varargin); nargin = length(varargin);
gpu = false; gpu = false;
......
% experimental block start
%==========================================================================
%> @brief Runs the MHTP-PALM4MSA hierarchical factorization algorithm on the matrix M.
%>
%> This algorithm uses the MHTP-PALM4MSA (matfaust.fact.palm4msa_mhtp) instead of only PALM4MSA as matfaust.fact.hierarchical.
%>
%> @param M the dense matrix to factorize.
%> @param hierarchical_p is a set of factorization parameters. See matfaust.fact.hierarchical.
%> @param mhtp_p the matfaust.factparams.MHTPParams instance to define the MHTP algorithm parameters.
%> @param varargin: see matfaust.fact.hierarchical for the other parameters.
%>
%>@b Example
%>@code
%> import matfaust.fact.hierarchical_mhtp
%> import matfaust.factparams.ParamsHierarchical
%> import matfaust.factparams.StoppingCriterion
%> import matfaust.factparams.MHTPParams
%> import matfaust.proj.*
%> M = rand(500,32);
%> fact_projs = { splin([500,32], 5), sp([32,32], 96), sp([32, 32], 96)};
%> res_projs = { normcol([32,32], 1), sp([32,32], 666), sp([32, 32], 333)};
%> stop_crit1 = StoppingCriterion(200)
%> stop_crit2 = StoppingCriterion(200)
%> % 50 iterations of MHTP will run every 100 iterations of PALM4MSA (each time PALM4MSA is called by the hierarchical algorithm)
%> mhtp_param = MHTPParams('num_its', 150, 'palm4msa_period', 100)
%> param = ParamsHierarchical(fact_projs, res_projs, stop_crit1, stop_crit2)
%> F = hierarchical_mhtp(M, param, mhtp_param)
%>@endcode
%>
%==========================================================================
function [F,lambda] = hierarchical_mhtp(M, hierarchical_p, mhtp_p, varargin)
hierarchical_p.use_MHTP = mhtp_p;
[F, lambda] = matfaust.fact.hierarchical(M, hierarchical_p, varargin{:}, 'backend', 2020);
end
% experimental block end
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
%> @param palm4msa_p the matfaust.factparams.ParamsPalm4MSA instance to define the algorithm parameters. %> @param palm4msa_p the matfaust.factparams.ParamsPalm4MSA instance to define the algorithm parameters.
%> @param mthp_p the matfaust.factparams.MHTPParams instance to define the MHTP algorithm parameters. %> @param mthp_p the matfaust.factparams.MHTPParams instance to define the MHTP algorithm parameters.
%> @param 'gpu', bool (optional) set to true to execute the algorithm using the GPU implementation. This options is only available when backend==2020. %> @param 'gpu', bool (optional) set to true to execute the algorithm using the GPU implementation. This options is only available when backend==2020.
%> @param varargin: see matfaust.fact.hierarchical for the other parameters.
%> %>
%> @retval F the Faust object result of the factorization. %> @retval F the Faust object result of the factorization.
%> @retval [F, lambda] = palm4msa(M, p) to optionally get lambda (scale). %> @retval [F, lambda] = palm4msa(M, p) to optionally get lambda (scale).
......
...@@ -158,6 +158,18 @@ classdef ParamsHierarchical < matfaust.factparams.ParamsFact ...@@ -158,6 +158,18 @@ classdef ParamsHierarchical < matfaust.factparams.ParamsFact
mex_constraints{2,i} = cur_cell; mex_constraints{2,i} = cur_cell;
end end
mex_params = struct('nfacts', this.num_facts, 'cons', {mex_constraints}, 'niter1', this.stop_crits{1}.num_its,'niter2', this.stop_crits{2}.num_its, 'sc_is_criterion_error', this.stop_crits{1}.is_criterion_error, 'sc_error_treshold', this.stop_crits{1}.tol, 'sc_max_num_its', this.stop_crits{1}.maxiter, 'sc_is_criterion_error2', this.stop_crits{2}.is_criterion_error, 'sc_error_treshold2', this.stop_crits{2}.tol, 'sc_max_num_its2', this.stop_crits{2}.maxiter, 'nrow', this.data_num_rows, 'ncol', this.data_num_cols, 'fact_side', this.is_fact_side_left, 'update_way', this.is_update_way_R2L, 'verbose', this.is_verbose, 'init_lambda', this.init_lambda, 'use_csr', this.use_csr, 'packing_RL', this.packing_RL, 'norm2_threshold', this.norm2_threshold, 'norm2_max_iter', this.norm2_max_iter, 'step_size', this.step_size, 'constant_step_size', this.constant_step_size); mex_params = struct('nfacts', this.num_facts, 'cons', {mex_constraints}, 'niter1', this.stop_crits{1}.num_its,'niter2', this.stop_crits{2}.num_its, 'sc_is_criterion_error', this.stop_crits{1}.is_criterion_error, 'sc_error_treshold', this.stop_crits{1}.tol, 'sc_max_num_its', this.stop_crits{1}.maxiter, 'sc_is_criterion_error2', this.stop_crits{2}.is_criterion_error, 'sc_error_treshold2', this.stop_crits{2}.tol, 'sc_max_num_its2', this.stop_crits{2}.maxiter, 'nrow', this.data_num_rows, 'ncol', this.data_num_cols, 'fact_side', this.is_fact_side_left, 'update_way', this.is_update_way_R2L, 'verbose', this.is_verbose, 'init_lambda', this.init_lambda, 'use_csr', this.use_csr, 'packing_RL', this.packing_RL, 'norm2_threshold', this.norm2_threshold, 'norm2_max_iter', this.norm2_max_iter, 'step_size', this.step_size, 'constant_step_size', this.constant_step_size);
if(~ (islogical(this.use_MHTP) && this.use_MHTP == false))
% use_MHTP must be a MHTPParams if not false (cf. ParamsFact)
if(~ isa(this.use_MHTP, 'matfaust.factparams.MHTPParams'))
error('use_MHTP is not a MHTPParams')
end
mhtp_p = this.use_MHTP;
mex_params.mhtp_num_its = mhtp_p.num_its;
mex_params.mhtp_constant_step_size = mhtp_p.constant_step_size;
mex_params.mhtp_step_size = mhtp_p.step_size;
mex_params.mhtp_palm4msa_period = mhtp_p.palm4msa_period;
mex_params.mhtp_updating_lambda = mhtp_p.updating_lambda;
end
end end
end end
......
...@@ -87,7 +87,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) ...@@ -87,7 +87,8 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// initialization of the matrix that will be factorized // initialization of the matrix that will be factorized
Faust::MatDense<SCALAR,Cpu> matrix; Faust::MatDense<SCALAR,Cpu> matrix;
mxArray2FaustMat(matlab_matrix,matrix); mxArray2FaustMat(matlab_matrix,matrix);
Faust::MHTPParams<SCALAR> mhtp_params;
mxArray2FaustMHTPParams<SCALAR>(matlab_params, mhtp_params);
auto *params = mxArray2FaustParams<SCALAR,FPP2>(matlab_params); auto *params = mxArray2FaustParams<SCALAR,FPP2>(matlab_params);
FPP2 lambda = params->init_lambda; FPP2 lambda = params->init_lambda;
...@@ -98,7 +99,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) ...@@ -98,7 +99,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
auto res_cons = params->cons[1]; auto res_cons = params->cons[1];
bool compute_2norm_on_arrays = false; bool compute_2norm_on_arrays = false;
std::vector<Faust::StoppingCriterion<Real<SCALAR>>> sc = {params->stop_crit_2facts.get_crit(), params->stop_crit_global.get_crit()}; std::vector<Faust::StoppingCriterion<Real<SCALAR>>> sc = {params->stop_crit_2facts.get_crit(), params->stop_crit_global.get_crit()};
auto th = Faust::hierarchical(matrix, sc, fac_cons, res_cons, lambda, params->isUpdateWayR2L, params->isFactSideLeft, params->use_csr, params->packing_RL, compute_2norm_on_arrays, params->norm2_threshold, params->norm2_max_iter, params->isVerbose, params->isConstantStepSize, params->step_size); auto th = Faust::hierarchical(matrix, sc, fac_cons, res_cons, lambda, params->isUpdateWayR2L, params->isFactSideLeft, params->use_csr, params->packing_RL, mhtp_params, compute_2norm_on_arrays, params->norm2_threshold, params->norm2_max_iter, params->isVerbose, params->isConstantStepSize, params->step_size);
auto th_times_lambda = th->multiply(lambda); auto th_times_lambda = th->multiply(lambda);
delete th; delete th;
......
...@@ -101,7 +101,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) ...@@ -101,7 +101,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
auto params = mxArray2FaustParamsPALM4MSA<SCALAR,FPP2>(prhs[0], presentFields); auto params = mxArray2FaustParamsPALM4MSA<SCALAR,FPP2>(prhs[0], presentFields);
Faust::MHTPParams<SCALAR> mhtp_params; Faust::MHTPParams<SCALAR> mhtp_params;
mxArray2FaustMHTPParams<SCALAR>(prhs[0], mhtp_params); mxArray2FaustMHTPParams<SCALAR>(prhs[0], mhtp_params);
std::cout << mhtp_params.to_string() << std::endl;
if(params->isVerbose) params->Display(); if(params->isVerbose) params->Display();
// Faust::BlasHandle<Cpu> blas_handle; // Faust::BlasHandle<Cpu> blas_handle;
SCALAR lambda = params->init_lambda; SCALAR lambda = params->init_lambda;
......
...@@ -143,7 +143,12 @@ void testCoherence(const mxArray* params,std::vector<bool> & presentFields) ...@@ -143,7 +143,12 @@ void testCoherence(const mxArray* params,std::vector<bool> & presentFields)
mexErrMsgTxt("The number of fields in params must be at least 3 "); mexErrMsgTxt("The number of fields in params must be at least 3 ");
for(int i=0;i<nbr_field;i++) for(int i=0;i<nbr_field;i++)
presentFields[mat_field_str2type(string(mxGetFieldNameByNumber(params,i)))] = true; {
auto field_name = string(mxGetFieldNameByNumber(params,i));
// dirty bypass to tolerate fields starting with mhtp (MHTPParams) parsed by mxArray2FaustMHTPParams
if (field_name.rfind("mhtp_", 0) != 0)
presentFields[mat_field_str2type(field_name)] = true;
}
} }
......
...@@ -969,34 +969,27 @@ void mxArray2FaustMHTPParams(const mxArray* matlab_params, Faust::MHTPParams<SCA ...@@ -969,34 +969,27 @@ void mxArray2FaustMHTPParams(const mxArray* matlab_params, Faust::MHTPParams<SCA
{ {
// all fields are optional // all fields are optional
mxArray *mx_field = mxGetField(matlab_params, 0, "mhtp_num_its"); mxArray *mx_field = mxGetField(matlab_params, 0, "mhtp_num_its");
std::cout << "mxArray2FaustMHTPParams" << std::endl; std::cout << "mx_field:" << mx_field << std::endl;
if(params.used = (mx_field != nullptr)) if(params.used = (mx_field != nullptr))
params.sc = Faust::StoppingCriterion<SCALAR>((int) mxGetScalar(mx_field)); params.sc = Faust::StoppingCriterion<SCALAR>((int) mxGetScalar(mx_field));
std::cout << "mxArray2FaustMHTPParams 1" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_constant_step_size"); mx_field = mxGetField(matlab_params, 0, "mhtp_constant_step_size");
if(params.used = (params.used || (mx_field != nullptr))) if(params.used = (params.used || (mx_field != nullptr)))
params.constant_step_size = (bool) mxGetScalar(mx_field); params.constant_step_size = (bool) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 2" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_step_size"); mx_field = mxGetField(matlab_params, 0, "mhtp_step_size");
if(params.used = (params.used || (mx_field != nullptr))) if(params.used = (params.used || (mx_field != nullptr)))
params.step_size = (Real<SCALAR>) mxGetScalar(mx_field); params.step_size = (Real<SCALAR>) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 3" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_palm4msa_period"); mx_field = mxGetField(matlab_params, 0, "mhtp_palm4msa_period");
if(params.used = (params.used || (mx_field != nullptr))) if(params.used = (params.used || (mx_field != nullptr)))
params.palm4msa_period = (int) mxGetScalar(mx_field); params.palm4msa_period = (int) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 4" << std::endl;
mx_field = mxGetField(matlab_params, 0, "mhtp_updating_lambda"); mx_field = mxGetField(matlab_params, 0, "mhtp_updating_lambda");
if(params.used = (params.used || (mx_field != nullptr))) if(params.used = (params.used || (mx_field != nullptr)))
params.updating_lambda = (bool) mxGetScalar(mx_field); params.updating_lambda = (bool) mxGetScalar(mx_field);
std::cout << "mxArray2FaustMHTPParams 5" << std::endl;
std::cout << "params.used" << params.used << std::endl;
} }
#endif #endif
...@@ -547,7 +547,8 @@ def hierarchical_mhtp(M, hierar_p, mhtp_p, ret_lambda=False, ret_params=False, ...@@ -547,7 +547,8 @@ def hierarchical_mhtp(M, hierar_p, mhtp_p, ret_lambda=False, ret_params=False,
Args: Args:
M: the numpy array to factorize. M: the numpy array to factorize.
p: is a set of hierarchical factorization parameters. See pyfaust.fact.hierarchical. p: is a set of hierarchical factorization parameters. See pyfaust.fact.hierarchical.
on_gpu: if True the GPU implementation is executed (this option applies only to 2020 backend). mhtp_p: the pyfaust.factparams.MHTPParams instance to define the MHTP algorithm parameters.
on_gpu: if True the GPU implementation is executed.
ret_lambda: set to True to ask the function to return the scale factor (False by default). ret_lambda: set to True to ask the function to return the scale factor (False by default).
ret_params: set to True to ask the function to return the ret_params: set to True to ask the function to return the
ParamsHierarchical instance used (False by default). ParamsHierarchical instance used (False by default).
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment