Mentions légales du service

Skip to content
Snippets Groups Projects
Commit c8b4bc5c authored by hhakim's avatar hhakim
Browse files

Auto-infer matrix dimension sizes from constraints in...

Auto-infer matrix dimension sizes from constraints in matfaust.factparams.ParamsHierarchicalFact and secure the type of p argument passed to FaustFactory.fact_hierarchical() and FaustFactory.fact_palm4msa().

The inference saves 2 arguments in FaustFactory.fact_hierarchical() (simplification).
parent 95936673
Branches
Tags 3.35.6
No related merge requests found
......@@ -105,7 +105,7 @@ classdef FaustFactoryTest < matlab.unittest.TestCase
stop_crit = StoppingCriterion(200);
stop_crit2 = StoppingCriterion(200);
params = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda,...
fact_cons, res_cons, size(M,1), size(M,2), {stop_crit, stop_crit2});
fact_cons, res_cons, {stop_crit, stop_crit2});
F = FaustFactory.fact_hierarchical(M, params)
this.verifyEqual(size(F), size(M))
%disp('norm F: ')
......@@ -145,7 +145,7 @@ classdef FaustFactoryTest < matlab.unittest.TestCase
stop_crit = StoppingCriterion(200);
stop_crit2 = StoppingCriterion(200);
params = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda,...
fact_cons, res_cons, size(M,1), size(M,2), {stop_crit, stop_crit2});
fact_cons, res_cons, {stop_crit, stop_crit2});
F = FaustFactory.fact_hierarchical(M, params)
this.verifyEqual(size(F), size(M))
%disp('norm F: ')
......
......@@ -77,5 +77,14 @@ classdef (Abstract) ParamsFact
p.is_verbose = is_verbose;
p.constant_step_size = constant_step_size;
end
function bool = is_mat_consistent(this, M)
if(~ ismatrix(M))
error('M must be a matrix.')
else
s = size(M);
bool = s(1) == this.constraints{1}.num_rows && s(2) == this.constraints{end}.num_cols;
end
end
end
end
......@@ -12,7 +12,7 @@ classdef ParamsHierarchicalFact < matfaust.factparams.ParamsFact
end
methods
function p = ParamsHierarchicalFact(varargin)
MIN_NARGIN = 8;
MIN_NARGIN = 6;
if(nargin < MIN_NARGIN)
error(['matfaust.factparams.ParamsHierarchicalFact() must receive at least',int2str(MIN_NARGIN),' arguments'])
end
......@@ -22,9 +22,8 @@ classdef ParamsHierarchicalFact < matfaust.factparams.ParamsFact
fact_constraints = varargin{4};
res_constraints = varargin{5};
constraints = {fact_constraints{:}, res_constraints{:}};
data_num_rows = floor(varargin{6});
data_num_cols = floor(varargin{7});
stop_crits = varargin{8};
% data_num_rows/data_num_cols are set by FaustFactory.fact_hierarchical()
stop_crits = varargin{6};
% set default values
is_fact_side_left = matfaust.factparams.ParamsHierarchicalFact.DEFAULT_IS_FACT_SIDE_LEFT;
step_size = matfaust.factparams.ParamsFact.DEFAULT_STEP_SIZE;
......@@ -45,35 +44,28 @@ classdef ParamsHierarchicalFact < matfaust.factparams.ParamsFact
% parent constructor handles verification for its own arguments
p = p@matfaust.factparams.ParamsFact(num_facts, is_update_way_R2L, init_lambda, ...
constraints, step_size, constant_step_size, is_verbose);
if(~ isscalar(data_num_rows) || ~ isreal(data_num_rows))
error('matfaust.factparams.ParamsHierarchicalFact 6th argument (data_num_rows) must be an integer.')
else
data_num_rows = floor(data_num_rows);
end
if(~ isscalar(data_num_cols) || ~ isreal(data_num_cols))
error('matfaust.factparams.ParamsHierarchicalFact 7th argument (data_num_cols) must be an integer.')
else
data_num_cols = floor(data_num_cols);
end
if(~ iscell(stop_crits))
error('matfaust.factparams.ParamsHierarchicalFact 8th argument (stop_crits) must be a cell array.')
error('matfaust.factparams.ParamsHierarchicalFact 6th argument (stop_crits) must be a cell array.')
if(length(stop_crits) ~= 2 )
error('matfaust.factparams.ParamsHierarchicalFact 8th argument (stop_crits) must be a cell array of 2 elements.')
error('matfaust.factparams.ParamsHierarchicalFact 6th argument (stop_crits) must be a cell array of 2 elements.')
end
for i = 1:length(stop_crits)
if(~ isa(stop_crits{i}, matfaust.factparams.StoppingCriterion))
error('matfaust.factparams.ParamsHierarchicalFact 8th argument (stop_crits) must contain matfaust.factparams.StoppingCriterion objects.')
error('matfaust.factparams.ParamsHierarchicalFact 6th argument (stop_crits) must contain matfaust.factparams.StoppingCriterion objects.')
end
end
end
if(~ islogical(is_fact_side_left))
error('matfaust.factparams.ParamsHierarchicalFact 13th argument (is_fact_side_left) must be logical.')
error('matfaust.factparams.ParamsHierarchicalFact 11th argument (is_fact_side_left) must be logical.')
end
p.stop_crits = stop_crits;
p.is_fact_side_left = is_fact_side_left;
p.data_num_rows = data_num_rows;
p.data_num_cols = data_num_cols;
p.is_fact_side_left = is_fact_side_left;
% auto-deduced to-factorize-matrix dim. sizes
p.data_num_rows = p.constraints{1}.num_rows;
p.data_num_cols = p.constraints{end}.num_cols;
end
end
end
......@@ -89,6 +89,9 @@ classdef FaustFactory
import matfaust.Faust
mex_constraints = cell(1, length(p.constraints));
matfaust.FaustFactory.check_fact_mat('FaustFactory.fact_palm4msa', M)
if(~ isa(p ,'matfaust.factparams.ParamsPalm4MSA'))
error('p must be a ParamsPalm4MSA object.')
end
for i=1:length(p.constraints)
cur_cell = cell(1, 4);
cur_cell{1} = p.constraints{i}.name.conv2str();
......@@ -97,6 +100,9 @@ classdef FaustFactory
cur_cell{4} = p.constraints{i}.num_cols;
mex_constraints{i} = cur_cell;
end
if(~ p.is_mat_consistent(M))
error('M''s number of columns must be consistent with the last residuum constraint defined in p. Likewise its number of rows must be consistent with the first factor constraint defined in p.')
end
% put mex_constraints in a cell array again because mex eats one level of array
mex_params = struct('data', M, 'nfacts', p.num_facts, 'cons', {mex_constraints}, 'init_facts', {p.init_facts}, 'niter', p.stop_crit.num_its, 'sc_is_criterion_error', p.stop_crit.is_criterion_error, 'sc_error_treshold', p.stop_crit.error_treshold, 'sc_max_num_its', p.stop_crit.max_num_its, 'update_way', p.is_update_way_R2L);
if(isreal(M))
......@@ -134,7 +140,7 @@ classdef FaustFactory
%> res_cons{3} = ConstraintInt('sp', 32, 32, 333);
%> stop_crit = StoppingCriterion(200);
%> stop_crit2 = StoppingCriterion(200);
%> params = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda, fact_cons, res_cons, size(M,1), size(M,2), {stop_crit, stop_crit2});
%> params = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda, fact_cons, res_cons, {stop_crit, stop_crit2});
%> F = FaustFactory.fact_hierarchical(M, params)
%> @endcode
%> Faust::HierarchicalFact<FPP,DEVICE>::compute_facts : factorisation 1/3<br/>
......@@ -155,6 +161,9 @@ classdef FaustFactory
import matfaust.factparams.*
mex_constraints = cell(2, p.num_facts-1);
matfaust.FaustFactory.check_fact_mat('FaustFactory.fact_hierarchical', M)
if(~ isa(p ,'ParamsHierarchicalFact'))
error('p must be a ParamsHierarchicalFact object.')
end
%mex_fact_constraints = cell(1, p.num_facts-1)
for i=1:p.num_facts-1
cur_cell = cell(1, 4);
......@@ -175,6 +184,10 @@ classdef FaustFactory
%mex_residuum_constraints{i} = cur_cell;
mex_constraints{2,i} = cur_cell;
end
if(~ p.is_mat_consistent(M))
error('M''s number of columns must be consistent with the last residuum constraint defined in p. Likewise its number of rows must be consistent with the first factor constraint defined in p.')
end
% the setters for num_rows/cols verifies consistency with constraints
mex_params = struct('data', M, 'nfacts', p.num_facts, 'cons', {mex_constraints}, 'niter1', p.stop_crits{1}.num_its,'niter2', p.stop_crits{2}.num_its, 'sc_is_criterion_error', p.stop_crits{1}.is_criterion_error, 'sc_error_treshold', p.stop_crits{1}.error_treshold, 'sc_max_num_its', p.stop_crits{1}.max_num_its, 'sc_is_criterion_error2', p.stop_crits{2}.is_criterion_error, 'sc_error_treshold2', p.stop_crits{2}.error_treshold, 'sc_max_num_its2', p.stop_crits{2}.max_num_its, 'nrow', p.data_num_rows, 'ncol', p.data_num_cols, 'fact_side', p.is_fact_side_left, 'update_way', p.is_update_way_R2L);
if(isreal(M))
[lambda, core_obj] = mexHierarchical_factReal(M, mex_params);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment