Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 95936673 authored by hhakim's avatar hhakim
Browse files

Add alt. signature to Constraint* constructors to ligthen instantiation code (pyfaut and matfaust).

Example got matfaust:
Before it was necessary to write something like: fact_cons{1} = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5);
Now it's easier with equivalent code: fact_cons{1} = ConstraintInt('splin', 500, 32, 5);
The former code is still valid though.
parent 413d61b7
Branches
Tags
No related merge requests found
...@@ -10,6 +10,10 @@ classdef ConstraintGeneric ...@@ -10,6 +10,10 @@ classdef ConstraintGeneric
end end
methods methods
function constraint = ConstraintGeneric(name, num_rows, num_cols, param) function constraint = ConstraintGeneric(name, num_rows, num_cols, param)
%ENOTE: ischar(name{:})
if(ischar(name) || iscell(name) && ischar(name{:}))
name = matfaust.factparams.ConstraintName(name);
end
constraint.name = name; constraint.name = name;
if(~ isreal(num_rows) || ~ isscalar(num_rows)) if(~ isreal(num_rows) || ~ isscalar(num_rows))
error('ConstraintGeneric 2nd argument must be an integer.') error('ConstraintGeneric 2nd argument must be an integer.')
......
...@@ -11,11 +11,11 @@ classdef ConstraintInt < matfaust.factparams.ConstraintGeneric ...@@ -11,11 +11,11 @@ classdef ConstraintInt < matfaust.factparams.ConstraintGeneric
if(~ isreal(param) || ~ isscalar(param)) if(~ isreal(param) || ~ isscalar(param))
error('ConstraintInt must receive an integer as param argument.') error('ConstraintInt must receive an integer as param argument.')
end end
if(~ isa(name, 'matfaust.factparams.ConstraintName') || ~ name.is_int_constraint()) constraint = constraint@matfaust.factparams.ConstraintGeneric(name, num_rows, num_cols, floor(param));
if(~ isa(constraint.name, 'matfaust.factparams.ConstraintName') || ~ constraint.name.is_int_constraint())
error(['ConstraintInt first argument must be a ConstraintName with a int type name ', ... error(['ConstraintInt first argument must be a ConstraintName with a int type name ', ...
'(name.is_int_constraint() must return True).']) '(name.is_int_constraint() must return True).'])
end end
constraint = constraint@matfaust.factparams.ConstraintGeneric(name, num_rows, num_cols, floor(param));
end end
end end
end end
...@@ -12,7 +12,7 @@ classdef ConstraintMat < matfaust.factparams.ConstraintGeneric ...@@ -12,7 +12,7 @@ classdef ConstraintMat < matfaust.factparams.ConstraintGeneric
error('ConstraintMat must receive a matrix as param argument.') error('ConstraintMat must receive a matrix as param argument.')
end end
constraint = constraint@matfaust.factparams.ConstraintGeneric(name, size(param, 1), size(param, 2), param); constraint = constraint@matfaust.factparams.ConstraintGeneric(name, size(param, 1), size(param, 2), param);
if(~ isa(name, 'matfaust.factparams.ConstraintName') || ~ name.is_mat_constraint()) if(~ isa(constraint.name, 'matfaust.factparams.ConstraintName') || ~ constraint.name.is_mat_constraint())
error('ConstraintMat first argument must be a ConstraintName with a matrix type name.') error('ConstraintMat first argument must be a ConstraintName with a matrix type name.')
end end
end end
......
...@@ -21,6 +21,9 @@ classdef ConstraintName ...@@ -21,6 +21,9 @@ classdef ConstraintName
function cons_name = ConstraintName(name) function cons_name = ConstraintName(name)
import matfaust.factparams.ConstraintName import matfaust.factparams.ConstraintName
if(ischar(name) || iscell(name))
name = ConstraintName.str2name_int(name);
end
if(name > ConstraintName.NORMLIN || name < ConstraintName.SP) %|| name == ConstraintName.BLKDIAG) if(name > ConstraintName.NORMLIN || name < ConstraintName.SP) %|| name == ConstraintName.BLKDIAG)
msg = 'name must be an integer among ConstraintName.SP, ConstraintName.SPCOL, ConstraintName.NORMCOL, ConstraintName.SPLINCOL, ConstraintName.CONST, ConstraintName.SP_POS, ConstraintName.SUPP, ConstraintName.NORMLIN'; msg = 'name must be an integer among ConstraintName.SP, ConstraintName.SPCOL, ConstraintName.NORMCOL, ConstraintName.SPLINCOL, ConstraintName.CONST, ConstraintName.SP_POS, ConstraintName.SUPP, ConstraintName.NORMLIN';
error(msg) error(msg)
...@@ -66,12 +69,42 @@ classdef ConstraintName ...@@ -66,12 +69,42 @@ classdef ConstraintName
str = 'supp'; str = 'supp';
case obj.CONST; case obj.CONST;
str = 'const'; str = 'const';
%case obj.BLKDIAG; %case obj.BLKDIAG;
% str = 'blkdiag' % str = 'blkdiag'
otherwise otherwise
error('Unknown name') error('Unknown name')
end end
end end
end
methods(Static)
function id = str2name_int(str)
import matfaust.factparams.ConstraintName
err_msg = 'Invalid argument to designate a ConstraintName.';
if(~ ischar(str) && ~ iscell(str))
error(err_msg)
end
switch(str)
case 'sp'
id = ConstraintName.SP;
case 'splin'
id = ConstraintName.SPLIN;
case 'spcol'
id = ConstraintName.SPCOL;
case 'splincol'
id = ConstraintName.SPLINCOL;
case 'sppos'
id = ConstraintName.SP_POS;
case 'normcol'
id = ConstraintName.NORMCOL;
case 'normlin'
id = ConstraintName.NORMLIN;
case 'supp'
id = ConstraintName.SUPP;
case 'const'
id = ConstraintName.CONST;
otherwise
error(err_msg)
end
end
end end
end end
...@@ -11,7 +11,7 @@ classdef ConstraintReal < matfaust.factparams.ConstraintGeneric ...@@ -11,7 +11,7 @@ classdef ConstraintReal < matfaust.factparams.ConstraintGeneric
error('ConstraintReal must receive a real as param argument.') error('ConstraintReal must receive a real as param argument.')
end end
constraint = constraint@matfaust.factparams.ConstraintGeneric(name, num_rows, num_cols, param); constraint = constraint@matfaust.factparams.ConstraintGeneric(name, num_rows, num_cols, param);
if(~ isa(name, 'matfaust.factparams.ConstraintName') || ~ name.is_real_constraint()) if(~ isa(constraint.name, 'matfaust.factparams.ConstraintName') || ~ constraint.name.is_real_constraint())
error('ConstraintReal first argument must be a ConstraintName with a real type name.') error('ConstraintReal first argument must be a ConstraintName with a real type name.')
end end
end end
......
...@@ -39,19 +39,19 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact ...@@ -39,19 +39,19 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
p = p@matfaust.factparams.ParamsFact(num_facts, is_update_way_R2L, init_lambda, ... p = p@matfaust.factparams.ParamsFact(num_facts, is_update_way_R2L, init_lambda, ...
constraints, step_size, constant_step_size, is_verbose); constraints, step_size, constant_step_size, is_verbose);
if(is_init_facts_to_default || iscell(init_facts) && length(init_facts) == 0) if(is_init_facts_to_default || iscell(init_facts) && length(init_facts) == 0)
init_facts = cell(num_facts, 1) init_facts = cell(num_facts, 1);
if(is_update_way_R2L) if(is_update_way_R2L)
zeros_id = num_facts zeros_id = num_facts;
else else
zeros_id = 1 zeros_id = 1;
end end
for i=1:num_facts for i=1:num_facts
if(i ~= zeros_id) if(i ~= zeros_id)
init_facts{i} = eye(constraints{i}.num_rows, constraints{i}.num_cols) init_facts{i} = eye(constraints{i}.num_rows, constraints{i}.num_cols);
end end
end end
init_facts{zeros_id} = ... init_facts{zeros_id} = ...
zeros(constraints{zeros_id}.num_rows, constraints{zeros_id}.num_cols) zeros(constraints{zeros_id}.num_rows, constraints{zeros_id}.num_cols);
elseif(~ iscell(init_facts)) % TODO: check init_facts length elseif(~ iscell(init_facts)) % TODO: check init_facts length
error('matfaust.factparams.ParamsFactPalm4MSA 4th argument (init_facts) must be a cell array.') error('matfaust.factparams.ParamsFactPalm4MSA 4th argument (init_facts) must be a cell array.')
else else
...@@ -62,7 +62,7 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact ...@@ -62,7 +62,7 @@ classdef ParamsPalm4MSA < matfaust.factparams.ParamsFact
end end
end end
end end
p.init_facts = init_facts p.init_facts = init_facts;
if(~ isa(stop_crit, 'matfaust.factparams.StoppingCriterion')) if(~ isa(stop_crit, 'matfaust.factparams.StoppingCriterion'))
error('matfaust.factparams.ParamsPalm4MSA argument (stop_crit) must be a matfaust.factparams.StoppingCriterion objects.') error('matfaust.factparams.ParamsPalm4MSA argument (stop_crit) must be a matfaust.factparams.StoppingCriterion objects.')
end end
......
...@@ -65,13 +65,13 @@ classdef FaustFactory ...@@ -65,13 +65,13 @@ classdef FaustFactory
%> @code %> @code
%> import matfaust.* %> import matfaust.*
%> import matfaust.factparams.* %> import matfaust.factparams.*
%> num_facts = 2 %> num_facts = 2;
%> is_update_way_R2L = false %> is_update_way_R2L = false;
%> init_lambda = 1.0 %> init_lambda = 1.0;
%> M = rand(500, 32) %> M = rand(500, 32);
%> cons = cell(2,1) %> cons = cell(2,1);
%> cons{1} = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5); %> cons{1} = ConstraintInt('splin', 500, 32, 5);
%> cons{2} = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1.0); %> cons{2} = ConstraintReal('normcol', 32, 32, 1.0);
%> stop_crit = StoppingCriterion(200); %> stop_crit = StoppingCriterion(200);
%> params = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda, cons, stop_crit); %> params = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda, cons, stop_crit);
%> F = FaustFactory.fact_palm4msa(M, params) %> F = FaustFactory.fact_palm4msa(M, params)
...@@ -126,12 +126,12 @@ classdef FaustFactory ...@@ -126,12 +126,12 @@ classdef FaustFactory
%> M = rand(500, 32); %> M = rand(500, 32);
%> fact_cons = cell(3, 1); %> fact_cons = cell(3, 1);
%> res_cons = cell(3, 1); %> res_cons = cell(3, 1);
%> fact_cons{1} = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5); %> fact_cons{1} = ConstraintInt('splin', 500, 32, 5);
%> fact_cons{2} = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 96); %> fact_cons{2} = ConstraintInt('sp', 32, 32, 96);
%> fact_cons{3} = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 96); %> fact_cons{3} = ConstraintInt('sp', 32, 32, 96);
%> res_cons{1} = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1); %> res_cons{1} = ConstraintReal('normcol', 32, 32, 1);
%> res_cons{2} = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 666); %> res_cons{2} = ConstraintInt('sp', 32, 32, 666);
%> res_cons{3} = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 333); %> res_cons{3} = ConstraintInt('sp', 32, 32, 333);
%> stop_crit = StoppingCriterion(200); %> stop_crit = StoppingCriterion(200);
%> stop_crit2 = StoppingCriterion(200); %> stop_crit2 = StoppingCriterion(200);
%> params = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda, fact_cons, res_cons, size(M,1), size(M,2), {stop_crit, stop_crit2}); %> params = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda, fact_cons, res_cons, size(M,1), size(M,2), {stop_crit, stop_crit2});
......
...@@ -1362,8 +1362,8 @@ class FaustFactory: ...@@ -1362,8 +1362,8 @@ class FaustFactory:
>>> is_update_way_R2L = False >>> is_update_way_R2L = False
>>> init_lambda = 1.0 >>> init_lambda = 1.0
>>> M = np.random.rand(500, 32) >>> M = np.random.rand(500, 32)
>>> cons1 = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5) >>> cons1 = ConstraintInt('splin', 500, 32, 5)
>>> cons2 = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1.0) >>> cons2 = ConstraintReal('normcol', 32, 32, 1.0)
>>> stop_crit = StoppingCriterion(num_its=200) >>> stop_crit = StoppingCriterion(num_its=200)
>>> # default step_size is 1e-16 >>> # default step_size is 1e-16
>>> param = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda, >>> param = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda,
...@@ -1410,12 +1410,12 @@ class FaustFactory: ...@@ -1410,12 +1410,12 @@ class FaustFactory:
>>> is_update_way_R2L = False >>> is_update_way_R2L = False
>>> init_lambda = 1.0 >>> init_lambda = 1.0
>>> M = np.random.rand(500, 32) >>> M = np.random.rand(500, 32)
>>> fact0_cons = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5) >>> fact0_cons = ConstraintInt('splin', 500, 32, 5)
>>> fact1_cons = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 96) >>> fact1_cons = ConstraintInt('sp', 32, 32, 96)
>>> fact2_cons = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 96) >>> fact2_cons = ConstraintInt('sp', 32, 32, 96)
>>> res0_cons = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1) >>> res0_cons = ConstraintReal('normcol', 32, 32, 1)
>>> res1_cons = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 666) >>> res1_cons = ConstraintInt('sp', 32, 32, 666)
>>> res2_cons = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 333) >>> res2_cons = ConstraintInt('sp', 32, 32, 333)
>>> stop_crit1 = StoppingCriterion(num_its=200) >>> stop_crit1 = StoppingCriterion(num_its=200)
>>> stop_crit2 = StoppingCriterion(num_its=200) >>> stop_crit2 = StoppingCriterion(num_its=200)
>>> param = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda, >>> param = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda,
......
...@@ -35,6 +35,8 @@ class ConstraintGeneric(object): ...@@ -35,6 +35,8 @@ class ConstraintGeneric(object):
cons_value: the parameter value of the constraint. cons_value: the parameter value of the constraint.
""" """
if(isinstance(name, str)):
name = ConstraintName(name)
self._name = name self._name = name
self._num_rows = num_rows self._num_rows = num_rows
self._num_cols = num_cols self._num_cols = num_cols
...@@ -78,7 +80,7 @@ class ConstraintInt(ConstraintGeneric): ...@@ -78,7 +80,7 @@ class ConstraintInt(ConstraintGeneric):
if(not isinstance(cons_value, np.int)): if(not isinstance(cons_value, np.int)):
raise TypeError('ConstraintInt must receive a int as cons_value ' raise TypeError('ConstraintInt must receive a int as cons_value '
'argument.') 'argument.')
if(not isinstance(name, ConstraintName) or not name.is_int_constraint()): if(not isinstance(self._name, ConstraintName) or not self._name.is_int_constraint()):
raise TypeError('ConstraintInt first argument must be a ' raise TypeError('ConstraintInt first argument must be a '
'ConstraintName with a int type name ' 'ConstraintName with a int type name '
'(name.is_int_constraint() must return True).') '(name.is_int_constraint() must return True).')
...@@ -92,7 +94,7 @@ class ConstraintMat(ConstraintGeneric): ...@@ -92,7 +94,7 @@ class ConstraintMat(ConstraintGeneric):
raise TypeError('ConstraintMat must receive a numpy matrix as cons_value ' raise TypeError('ConstraintMat must receive a numpy matrix as cons_value '
'argument.') 'argument.')
self.cons_value = float(self.cons_value) self.cons_value = float(self.cons_value)
if(not isinstance(name, ConstraintName) or not name.is_mat_constraint()): if(not isinstance(self._name, ConstraintName) or not self._name.is_mat_constraint()):
raise TypeError('ConstraintMat first argument must be a ' raise TypeError('ConstraintMat first argument must be a '
'ConstraintName with a matrix type name ' 'ConstraintName with a matrix type name '
'(name.is_mat_constraint() must return True)') '(name.is_mat_constraint() must return True)')
...@@ -124,7 +126,7 @@ class ConstraintReal(ConstraintGeneric): ...@@ -124,7 +126,7 @@ class ConstraintReal(ConstraintGeneric):
raise TypeError('ConstraintReal must receive a float as cons_value ' raise TypeError('ConstraintReal must receive a float as cons_value '
'argument.') 'argument.')
self._cons_value = float(self._cons_value) self._cons_value = float(self._cons_value)
if(not isinstance(name, ConstraintName) or not name.is_real_constraint()): if(not isinstance(self._name, ConstraintName) or not self._name.is_real_constraint()):
raise TypeError('ConstraintReal first argument must be a ' raise TypeError('ConstraintReal first argument must be a '
'ConstraintName with a real type name ' 'ConstraintName with a real type name '
'(name.is_real_constraint() must return True).') '(name.is_real_constraint() must return True).')
...@@ -160,6 +162,8 @@ class ConstraintName: ...@@ -160,6 +162,8 @@ class ConstraintName:
NORMLIN = 9 # Real Constraint NORMLIN = 9 # Real Constraint
def __init__(self, name): def __init__(self, name):
if(isinstance(name,str)):
name = ConstraintName.str2name_int(name)
if(not isinstance(name, np.int) or name < ConstraintName.SP or name > ConstraintName.NORMLIN): if(not isinstance(name, np.int) or name < ConstraintName.SP or name > ConstraintName.NORMLIN):
raise ValueError("name must be an integer among ConstraintName.SP," raise ValueError("name must be an integer among ConstraintName.SP,"
"ConstraintName.SPCOL, ConstraintName.NORMCOL," "ConstraintName.SPCOL, ConstraintName.NORMCOL,"
...@@ -179,6 +183,32 @@ class ConstraintName: ...@@ -179,6 +183,32 @@ class ConstraintName:
def is_mat_constraint(self): def is_mat_constraint(self):
return self.name in [ConstraintName.SUPP, ConstraintName.CONST ] return self.name in [ConstraintName.SUPP, ConstraintName.CONST ]
@staticmethod
def str2name_int(_str):
err_msg = "Invalid argument to designate a ConstraintName."
if(not isinstance(_str, str)):
raise ValueError(err_msg)
if(_str == 'sp'):
id = ConstraintName.SP
elif(_str == 'splin'):
id = ConstraintName.SPLIN
elif(_str == 'spcol'):
id = ConstraintName.SPCOL
elif(_str == 'splincol'):
id = ConstraintName.SPLINCOL
elif(_str == 'sppos'):
id = ConstraintName.SP_POS
elif(_str == 'normcol'):
id = ConstraintName.NORMCOL
elif(_str == 'normlin'):
id = ConstraintName.NORMLIN
elif(_str == 'supp'):
id = ConstraintName.SUPP
elif(_str == 'const'):
id = ConstraintName.CONST
else:
raise ValueError(err_msg)
return id
class ParamsFact(object): class ParamsFact(object):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment