Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9d1bb8f9 authored by hhakim's avatar hhakim
Browse files

Handle use_csr and packing_RL parameters in matfaust.fact.palm4msa (as it is...

Handle use_csr and packing_RL parameters in matfaust.fact.palm4msa (as it is already in hiearchical) for the 2020 impl.

- mxArray2FaustParamsPALM4MSA update.
- matfaust.factparams.ParamsPalm4MSA.to_mex_struct.
- C++ struct src/algorithm/factorization/faust_ParamsPalm.*.
parent c2784a75
Branches
Tags
No related merge requests found
...@@ -107,6 +107,8 @@ namespace Faust ...@@ -107,6 +107,8 @@ namespace Faust
GradientCalcOptMode gradCalcOptMode; GradientCalcOptMode gradCalcOptMode;
Real<FPP> norm2_threshold; Real<FPP> norm2_threshold;
unsigned int norm2_max_iter; unsigned int norm2_max_iter;
bool use_csr;
bool packing_RL;
void Display() const; void Display() const;
void init_factors(); void init_factors();
......
...@@ -168,6 +168,8 @@ void Faust::ParamsPalm<FPP,DEVICE,FPP2>::Display() const ...@@ -168,6 +168,8 @@ void Faust::ParamsPalm<FPP,DEVICE,FPP2>::Display() const
std::cout << "gradCalcOptMode: "<< gradCalcOptMode << std::endl; std::cout << "gradCalcOptMode: "<< gradCalcOptMode << std::endl;
std::cout << "norm2_threshold:" << norm2_threshold << std::endl; std::cout << "norm2_threshold:" << norm2_threshold << std::endl;
std::cout << "norm2_max_iter:" << norm2_max_iter << std::endl; std::cout << "norm2_max_iter:" << norm2_max_iter << std::endl;
std::cout << "use_csr:" << use_csr << std::endl;
std::cout << "packing_RL:" << packing_RL << std::endl;
/*cout<<"INIT_FACTS :"<<endl; /*cout<<"INIT_FACTS :"<<endl;
for (int L=0;L<init_fact.size();L++)init_fact[L].Display();*/ for (int L=0;L<init_fact.size();L++)init_fact[L].Display();*/
......
...@@ -50,7 +50,7 @@ classdef (Abstract) ParamsFact ...@@ -50,7 +50,7 @@ classdef (Abstract) ParamsFact
%> default is EXTERNAL_OPT %> default is EXTERNAL_OPT
DEFAULT_OPT = 2 DEFAULT_OPT = 2
% the order of names matters and must respect the indices above % the order of names matters and must respect the indices above
OPT_ARG_NAMES = {'is_update_way_R2L', 'init_lambda', 'step_size', 'constant_step_size', 'is_verbose', 'grad_calc_opt_mode', 'norm2_max_iter', 'norm2_threshold', 'packing_RL', 'use_csr' } OPT_ARG_NAMES = {'is_update_way_R2L', 'init_lambda', 'step_size', 'constant_step_size', 'is_verbose', 'grad_calc_opt_mode', 'norm2_max_iter', 'norm2_threshold', 'use_csr', 'packing_RL'}
end end
methods methods
function p = ParamsFact(num_facts, constraints, varargin) function p = ParamsFact(num_facts, constraints, varargin)
......
...@@ -95,8 +95,7 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact ...@@ -95,8 +95,7 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
mex_constraints{i} = cur_cell; mex_constraints{i} = cur_cell;
end end
% put mex_constraints in a cell array again because mex eats one level of array % put mex_constraints in a cell array again because mex eats one level of array
mex_params = struct('data', M, 'nfacts', this.num_facts, 'cons', {mex_constraints}, 'init_facts', {this.init_facts}, 'niter', this.stop_crit.num_its, 'sc_is_criterion_error', this.stop_crit.is_criterion_error, 'sc_error_treshold', this.stop_crit.tol, 'sc_max_num_its', this.stop_crit.maxiter, 'update_way', this.is_update_way_R2L, 'grad_calc_opt_mode', this.grad_calc_opt_mode, 'constant_step_size', this.constant_step_size, 'step_size', this.step_size, 'verbose', this.is_verbose, 'norm2_max_iter', this.norm2_max_iter, 'norm2_threshold', this.norm2_threshold, 'init_lambda', this.init_lambda); mex_params = struct('data', M, 'nfacts', this.num_facts, 'cons', {mex_constraints}, 'init_facts', {this.init_facts}, 'niter', this.stop_crit.num_its, 'sc_is_criterion_error', this.stop_crit.is_criterion_error, 'sc_error_treshold', this.stop_crit.tol, 'sc_max_num_its', this.stop_crit.maxiter, 'update_way', this.is_update_way_R2L, 'grad_calc_opt_mode', this.grad_calc_opt_mode, 'constant_step_size', this.constant_step_size, 'step_size', this.step_size, 'verbose', this.is_verbose, 'norm2_max_iter', this.norm2_max_iter, 'norm2_threshold', this.norm2_threshold, 'init_lambda', this.init_lambda, 'use_csr', this.use_csr, 'packing_RL', this.packing_RL);
end end
end end
methods methods
......
...@@ -67,29 +67,29 @@ mxArray* FaustMat2mxArray(const Faust::MatDense<FPP,Cpu>& M) ...@@ -67,29 +67,29 @@ mxArray* FaustMat2mxArray(const Faust::MatDense<FPP,Cpu>& M)
{ {
if (!M.isReal()) if (!M.isReal())
mexErrMsgTxt("FaustMat2mxArray : Faust::MatDense must be real"); mexErrMsgTxt("FaustMat2mxArray : Faust::MatDense must be real");
mxArray * mxMat; mxArray * mxMat;
int row,col; int row,col;
row = M.getNbRow(); row = M.getNbRow();
col = M.getNbCol(); col = M.getNbCol();
/*mxMat = mxCreateDoubleMatrix(row,col,mxREAL); /*mxMat = mxCreateDoubleMatrix(row,col,mxREAL);
double* mxMatdata = mxGetPr(mxMat); double* mxMatdata = mxGetPr(mxMat);
if (typeid(double) == typeid(FPP)) if (typeid(double) == typeid(FPP))
memcpy(mxMatdata,M.getData(),sizeof(double)*row*col); memcpy(mxMatdata,M.getData(),sizeof(double)*row*col);
else else
{ {
double* mat_ptr = (double *) mxCalloc(row*col,sizeof(double)); double* mat_ptr = (double *) mxCalloc(row*col,sizeof(double));
for (int i=0;i<row*col;i++) for (int i=0;i<row*col;i++)
{ {
mat_ptr[i] = (double) M.getData()[i]; mat_ptr[i] = (double) M.getData()[i];
} }
memcpy(mxMatdata,mat_ptr,sizeof(double)*row*col); memcpy(mxMatdata,mat_ptr,sizeof(double)*row*col);
mxFree(mat_ptr); mxFree(mat_ptr);
}*/ }*/
const mwSize dims[3]={(mwSize)row,(mwSize)col}; const mwSize dims[3]={(mwSize)row,(mwSize)col};
if(typeid(FPP)==typeid(float)) if(typeid(FPP)==typeid(float))
{ {
...@@ -101,13 +101,13 @@ mxArray* FaustMat2mxArray(const Faust::MatDense<FPP,Cpu>& M) ...@@ -101,13 +101,13 @@ mxArray* FaustMat2mxArray(const Faust::MatDense<FPP,Cpu>& M)
{ {
mexErrMsgTxt("FaustMat2mxArray : unsupported type of float"); mexErrMsgTxt("FaustMat2mxArray : unsupported type of float");
} }
FPP* ptr_out = static_cast<FPP*> (mxGetData(mxMat)); FPP* ptr_out = static_cast<FPP*> (mxGetData(mxMat));
memcpy(ptr_out, M.getData(),row*col*sizeof(FPP)); memcpy(ptr_out, M.getData(),row*col*sizeof(FPP));
return mxMat; return mxMat;
} }
......
...@@ -151,8 +151,8 @@ void testCoherencePALM4MSA(const mxArray* params,std::vector<bool> & presentFiel ...@@ -151,8 +151,8 @@ void testCoherencePALM4MSA(const mxArray* params,std::vector<bool> & presentFiel
{ {
////TODO: this function should be modified to be more reliable/simple as the function testCoherence() in mx2Faust.cpp has been modified ////TODO: this function should be modified to be more reliable/simple as the function testCoherence() in mx2Faust.cpp has been modified
int nbr_field=mxGetNumberOfFields(params); int nbr_field=mxGetNumberOfFields(params);
presentFields.resize(16); presentFields.resize(18);
presentFields.assign(16,false); presentFields.assign(18,false);
if(nbr_field < 3) if(nbr_field < 3)
{ {
mexErrMsgTxt("The number of field of params must be at least 3 "); mexErrMsgTxt("The number of field of params must be at least 3 ");
...@@ -227,6 +227,10 @@ void testCoherencePALM4MSA(const mxArray* params,std::vector<bool> & presentFiel ...@@ -227,6 +227,10 @@ void testCoherencePALM4MSA(const mxArray* params,std::vector<bool> & presentFiel
presentFields[14] = true; presentFields[14] = true;
else if(strcmp(fieldName, "norm2_threshold") == 0) else if(strcmp(fieldName, "norm2_threshold") == 0)
presentFields[15] = true; presentFields[15] = true;
else if(strcmp(fieldName, "use_csr") == 0)
presentFields[16] = true;
else if(strcmp(fieldName, "packing_RL") == 0)
presentFields[17] = true;
} }
} }
......
...@@ -966,7 +966,6 @@ const ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* ma ...@@ -966,7 +966,6 @@ const ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* ma
mexErrMsgTxt("init_facts must be must be specified"); mexErrMsgTxt("init_facts must be must be specified");
} }
// std::cout<<"PASSER1"<<std::endl;
//verbosity //verbosity
bool isVerbose = false; bool isVerbose = false;
if (presentFields[5]) if (presentFields[5])
...@@ -1025,6 +1024,20 @@ const ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* ma ...@@ -1025,6 +1024,20 @@ const ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* ma
mxCurrentField = mxGetField(matlab_params, 0, "norm2_threshold"); mxCurrentField = mxGetField(matlab_params, 0, "norm2_threshold");
norm2_threshold = (FPP2) mxGetScalar(mxCurrentField); norm2_threshold = (FPP2) mxGetScalar(mxCurrentField);
} }
bool use_csr = Params<SCALAR, Cpu, FPP2>::defaultUseCSR;
bool packing_RL = Params<SCALAR, Cpu, FPP2>::defaultPackingRL;
if(presentFields[16])
{
mxCurrentField = mxGetField(matlab_params, 0, "use_csr");
use_csr = (bool) mxGetScalar(mxCurrentField);
std::cout << "mx2Faust use_csr:" << use_csr << std::endl;
}
if(presentFields[17])
{
mxCurrentField = mxGetField(matlab_params, 0, "packing_RL");
packing_RL = (bool) mxGetScalar(mxCurrentField);
std::cout << "mx2Faust packing_RL:" << packing_RL << std::endl;
}
//compute_lambda //compute_lambda
// bool compute_lambda = true; // bool compute_lambda = true;
// if (presentFields[8]) // if (presentFields[8])
...@@ -1033,11 +1046,22 @@ const ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* ma ...@@ -1033,11 +1046,22 @@ const ParamsPalm<SCALAR,Cpu,FPP2>* mxArray2FaustParamsPALM4MSA(const mxArray* ma
// compute_lambda = (bool) mxGetScalar(mxCurrentField); // compute_lambda = (bool) mxGetScalar(mxCurrentField);
// } // }
params = new Faust::ParamsPalm<SCALAR,Cpu, FPP2>(data,nbFact,consS,init_facts,crit1,isVerbose,updateway,init_lambda, constant_step_size, step_size, grad_calc_opt_mode); params = new Faust::ParamsPalm<SCALAR,Cpu, FPP2>(data,
nbFact,
consS,
init_facts,
crit1,
isVerbose,
updateway,
init_lambda,
constant_step_size,
step_size,
grad_calc_opt_mode);
if(norm2_max_iter) params->norm2_max_iter = norm2_max_iter; if(norm2_max_iter) params->norm2_max_iter = norm2_max_iter;
if(norm2_threshold != FPP2(0)) params->norm2_threshold = norm2_threshold; if(norm2_threshold != FPP2(0)) params->norm2_threshold = norm2_threshold;
params->use_csr = use_csr;
params->packing_RL = packing_RL;
return params; return params;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment