Mentions légales du service

Skip to content
Snippets Groups Projects
Commit bb39efc2 authored by hhakim's avatar hhakim
Browse files

Ensure that matfaust PALM4MSA init_facts and constraints matrices passed as...

Ensure that matfaust PALM4MSA init_facts and constraints matrices passed as parameters are in the same field (complex, real) and the same precision (double, single) as the matrix to factorize.

If the assertion is not respected a comprehensive error stops the execution.
Fix also init_facts customized parameter always ignored (using default init_facts).
parent 57d08b13
No related branches found
No related tags found
No related merge requests found
......@@ -96,7 +96,8 @@ function [F,lambda] = palm4msa(M, p, varargin)
dev = 'gpu';
end
if(isreal(M))
init_faust = matfaust.Faust(p.init_facts, 'dtype', dtype, 'dev', dev);
% do not use p.init_facts because if default init_facts is used it is created on the fly when palm4msa is called (in order to be type-class consistent with M)
init_faust = matfaust.Faust(mex_params.init_facts, 'dtype', dtype, 'dev', dev);
if(gpu)
if(is_float)
[lambda, core_obj] = mexPALM4MSA2020_gpu2RealFloat(mex_params, get_handle(init_faust));
......@@ -111,7 +112,7 @@ function [F,lambda] = palm4msa(M, p, varargin)
end
end
else
init_faust = complex(matfaust.Faust(p.init_facts, 'dev', dev));
init_faust = complex(matfaust.Faust(mex_params.init_facts, 'dev', dev));
if(gpu)
[lambda, core_obj] = mexPALM4MSA2020_gpu2Cplx(mex_params, get_handle(init_faust));
else
......
......@@ -61,14 +61,14 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
% parent constructor handles verification for its own arguments
end
p = p@matfaust.factparams.ParamsFact(num_facts, constraints, parent_args{:});
init_facts_name = p.OPT_ARG_NAMES{p.IDX_INIT_FACTS};
init_facts_name = p.OPT_ARG_NAMES2{p.IDX_INIT_FACTS};
try
init_facts = opt_arg_map(init_facts_name);
catch
% arg int_facts not passed
% arg init_facts not passed
end
if(~ exist(init_facts_name) || iscell(init_facts) && length(init_facts) == 0)
init_facts = p.get_default_init_facts(num_facts);
init_facts = {'default_init_facts'};
elseif(~ iscell(init_facts)) % TODO: check init_facts length
error(['matfaust.factparams.ParamsFactPalm4MSA argument ' init_facts_name ' must be a cell array.'])
else
......@@ -91,16 +91,51 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
function mex_params = to_mex_struct(this, M)
mex_constraints = cell(1, length(this.constraints));
is_single = @(mat) strcmp('single', class(mat));
M_is_single = is_single(M);
M_is_real = isreal(M);
for i=1:length(this.constraints)
cur_cell = cell(1, 4);
cur_cell{1} = this.constraints{i}.name.conv2str();
if(isa(this.constraints{i}, 'matfaust.factparams.ConstraintMat'))
if(~ M_is_single == is_single(this.constraints{i}.param))
error(['ParamsPalm4MSA.constraints, constraint #' int2str(i) ' must be in single precision (resp. double) if the matrix to factorize is in single precision (resp. double).'])
end
if(~ M_is_real == isreal(this.constraints{i}.param))
error(['ParamsPalm4MSA.constraints, constraint #' int2str(i) ' must be real (resp. complex) if the matrix to factorize is real (resp. complex).'])
end
end
cur_cell{2} = this.constraints{i}.param;
cur_cell{3} = this.constraints{i}.num_rows;
cur_cell{4} = this.constraints{i}.num_cols;
mex_constraints{i} = cur_cell;
end
if(iscell(this.init_facts) && strcmp(this.init_facts{1}, 'default_init_facts'))
if(strcmp(class(M), 'single'))
class_func = @(mat) single(mat);
else
class_func = @(mat) double(mat);
end
if(isreal(M))
type_func = @(mat) real(mat);
else
type_func = @(mat) complex(mat);
end
init_facts = this.get_default_init_facts(length(this.constraints), class_func, type_func);
else
init_facts = this.init_facts;
% verify M / init_facts type/precision consistency
for i=1:length(init_facts)
if(~ M_is_real == isreal(init_facts{i}))
error(['ParamsPalm4MSA.init_facts, factor ' int2str(i) ' must be real (resp. complex) if the matrix to factorize is real (resp. complex).'])
end
if(~ M_is_single == is_single(this.init_facts{i}))
error(['ParamsPalm4MSA.init_facts, factor ' int2str(i) ' be in single precision (resp. double) if the matrix to factorize is in single precision (resp. double).'])
end
end
end
% put mex_constraints in a cell array again because mex eats one level of array
mex_params = struct('data', M, 'nfacts', this.num_facts, 'cons', {mex_constraints}, 'init_facts', {this.init_facts}, 'niter', this.stop_crit.num_its, 'sc_is_criterion_error', this.stop_crit.is_criterion_error, 'sc_error_treshold', this.stop_crit.tol, 'sc_max_num_its', this.stop_crit.maxiter, 'update_way', this.is_update_way_R2L, 'grad_calc_opt_mode', this.grad_calc_opt_mode, 'constant_step_size', this.constant_step_size, 'step_size', this.step_size, 'verbose', this.is_verbose, 'norm2_max_iter', this.norm2_max_iter, 'norm2_threshold', this.norm2_threshold, 'init_lambda', this.init_lambda, 'factor_format', matfaust.factparams.ParamsFact.factor_format_str2int(this.factor_format), 'packing_RL', this.packing_RL, 'no_normalization', this.no_normalization, 'no_lambda', this.no_lambda);
mex_params = struct('data', M, 'nfacts', this.num_facts, 'cons', {mex_constraints}, 'init_facts', {init_facts}, 'niter', this.stop_crit.num_its, 'sc_is_criterion_error', this.stop_crit.is_criterion_error, 'sc_error_treshold', this.stop_crit.tol, 'sc_max_num_its', this.stop_crit.maxiter, 'update_way', this.is_update_way_R2L, 'grad_calc_opt_mode', this.grad_calc_opt_mode, 'constant_step_size', this.constant_step_size, 'step_size', this.step_size, 'verbose', this.is_verbose, 'norm2_max_iter', this.norm2_max_iter, 'norm2_threshold', this.norm2_threshold, 'init_lambda', this.init_lambda, 'factor_format', matfaust.factparams.ParamsFact.factor_format_str2int(this.factor_format), 'packing_RL', this.packing_RL, 'no_normalization', this.no_normalization, 'no_lambda', this.no_lambda);
if(~ (islogical(this.use_MHTP) && this.use_MHTP == false))
% use_MHTP must be a MHTPParams if not false (cf. ParamsFact)
if(~ isa(this.use_MHTP, 'matfaust.factparams.MHTPParams'))
......@@ -116,7 +151,7 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
end
end
methods
function init_facts = get_default_init_facts(p, num_facts)
function init_facts = get_default_init_facts(p, num_facts, class_func, type_func)
init_facts = cell(num_facts, 1);
if(p.is_update_way_R2L)
zeros_id = num_facts;
......@@ -125,11 +160,11 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
end
for i=1:num_facts
if(i ~= zeros_id)
init_facts{i} = eye(p.constraints{i}.num_rows, p.constraints{i}.num_cols);
init_facts{i} = type_func(class_func(eye(p.constraints{i}.num_rows, p.constraints{i}.num_cols)));
end
end
init_facts{zeros_id} = ...
zeros(p.constraints{zeros_id}.num_rows, p.constraints{zeros_id}.num_cols);
type_func(class_func(zeros(p.constraints{zeros_id}.num_rows, p.constraints{zeros_id}.num_cols)));
end
end
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment