Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 17c9d8bf authored by hhakim's avatar hhakim
Browse files

Update pyfaust wrapper (constraint and param classes for facto.).

- Replace pyfaust.ConstraintScalar by pyfaust.ConstraintInt and ConstraintReal (like in matfaust).
- Reinforce the arguments checking in constructors of ConstraintInt, ConstraintReal, StoppingCriterion, ParamsHierarchicalFact, ParamsPalm4MSA.
- Other minor changes (including in matfaust).
parent be9ec34c
No related branches found
No related tags found
No related merge requests found
...@@ -305,8 +305,8 @@ class TestFaustFactory(unittest.TestCase): ...@@ -305,8 +305,8 @@ class TestFaustFactory(unittest.TestCase):
def testFactPalm4MSA(self): def testFactPalm4MSA(self):
print("Test FaustFactory.fact_palm4msa()") print("Test FaustFactory.fact_palm4msa()")
from pyfaust import FaustFactory, ParamsPalm4MSA, ConstraintScalar,\ from pyfaust import FaustFactory, ParamsPalm4MSA, ConstraintReal,\
ConstraintName, StoppingCriterion ConstraintInt, ConstraintName, StoppingCriterion
num_facts = 2 num_facts = 2
is_update_way_R2L = False is_update_way_R2L = False
init_lambda = 1.0 init_lambda = 1.0
...@@ -317,8 +317,8 @@ class TestFaustFactory(unittest.TestCase): ...@@ -317,8 +317,8 @@ class TestFaustFactory(unittest.TestCase):
M = \ M = \
loadmat(sys.path[-1]+"/../../../misc/data/mat/config_compared_palm2.mat")['data'] loadmat(sys.path[-1]+"/../../../misc/data/mat/config_compared_palm2.mat")['data']
# default step_size # default step_size
cons1 = ConstraintScalar(ConstraintName(ConstraintName.SPLIN), 500, 32, 5) cons1 = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5)
cons2 = ConstraintScalar(ConstraintName(ConstraintName.NORMCOL), 32, cons2 = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32,
32, 1.0) 32, 1.0)
stop_crit = StoppingCriterion(num_its=200) stop_crit = StoppingCriterion(num_its=200)
param = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda, param = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda,
...@@ -337,8 +337,8 @@ class TestFaustFactory(unittest.TestCase): ...@@ -337,8 +337,8 @@ class TestFaustFactory(unittest.TestCase):
def testFactHierarch(self): def testFactHierarch(self):
print("Test FaustFactory.fact_hierarchical()") print("Test FaustFactory.fact_hierarchical()")
from pyfaust import FaustFactory, ParamsHierarchicalFact, ConstraintScalar,\ from pyfaust import FaustFactory, ParamsHierarchicalFact, ConstraintReal,\
ConstraintName, StoppingCriterion ConstraintInt, ConstraintName, StoppingCriterion
num_facts = 4 num_facts = 4
is_update_way_R2L = False is_update_way_R2L = False
init_lambda = 1.0 init_lambda = 1.0
...@@ -350,16 +350,16 @@ class TestFaustFactory(unittest.TestCase): ...@@ -350,16 +350,16 @@ class TestFaustFactory(unittest.TestCase):
M = \ M = \
loadmat(sys.path[-1]+"/../../../misc/data/mat/matrix_hierarchical_fact.mat")['matrix'] loadmat(sys.path[-1]+"/../../../misc/data/mat/matrix_hierarchical_fact.mat")['matrix']
# default step_size # default step_size
cons1 = ConstraintScalar(ConstraintName(ConstraintName.SPLIN), 500, 32, 5) cons1 = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5)
cons2 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, cons2 = ConstraintInt(ConstraintName(ConstraintName.SP), 32,
32, 96) 32, 96)
cons3 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, 32, cons3 = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32,
96) 96)
cons4 = ConstraintScalar(ConstraintName(ConstraintName.NORMCOL), 32, 32, cons4 = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32, 32,
1) 1)
cons5 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, 32, cons5 = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32,
666) 666)
cons6 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, 32, cons6 = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32,
333) 333)
stop_crit1 = StoppingCriterion(num_its=200) stop_crit1 = StoppingCriterion(num_its=200)
stop_crit2 = StoppingCriterion(num_its=200) stop_crit2 = StoppingCriterion(num_its=200)
......
...@@ -5,7 +5,7 @@ classdef ConstraintInt < matfaust.ConstraintGeneric ...@@ -5,7 +5,7 @@ classdef ConstraintInt < matfaust.ConstraintGeneric
methods methods
function constraint = ConstraintInt(name, num_rows, num_cols, param) function constraint = ConstraintInt(name, num_rows, num_cols, param)
% check param is a int % check param is a int
if(~ isreal(param) || ~ isinteger(int64(param)) || all(size(param) ~= [1, 1])) if(~ isreal(param) || ~ isinteger(int64(param)) || ~ isscalar(param))
error('ConstraintInt must receive an integer as param argument.') error('ConstraintInt must receive an integer as param argument.')
end end
constraint = constraint@matfaust.ConstraintGeneric(name, num_rows, num_cols, floor(param)); constraint = constraint@matfaust.ConstraintGeneric(name, num_rows, num_cols, floor(param));
......
...@@ -5,7 +5,7 @@ classdef ConstraintReal < matfaust.ConstraintGeneric ...@@ -5,7 +5,7 @@ classdef ConstraintReal < matfaust.ConstraintGeneric
methods methods
function constraint = ConstraintReal(name, num_rows, num_cols, param) function constraint = ConstraintReal(name, num_rows, num_cols, param)
% check param is a real scalar % check param is a real scalar
if(~ isreal(param) || all(size(param) ~= [1, 1])) if(~ isreal(param) || ~ isscalar(param))
error('ConstraintReal must receive a real as param argument.') error('ConstraintReal must receive a real as param argument.')
end end
constraint = constraint@matfaust.ConstraintGeneric(name, num_rows, num_cols, param); constraint = constraint@matfaust.ConstraintGeneric(name, num_rows, num_cols, param);
......
...@@ -50,9 +50,9 @@ classdef (Abstract) ParamsFact ...@@ -50,9 +50,9 @@ classdef (Abstract) ParamsFact
if(~ iscell(constraints)) if(~ iscell(constraints))
error('matfaust.ParamsFact 5th argument (constraints) must be a cell array.') error('matfaust.ParamsFact 5th argument (constraints) must be a cell array.')
end end
for i = 1:length(constraints) %TODO: check constraints length for i = 1:length(constraints) %TODO: check constraints length in sub-class
if(~ isa(constraints{i}, 'matfaust.ConstraintGeneric')) if(~ isa(constraints{i}, 'matfaust.ConstraintGeneric'))
error('matfaust.ParamsFact 5th argument (constraints) must contain matfaust.ConstraintName objects.') error('matfaust.ParamsFact 5th argument (constraints) must contain matfaust.ConstraintGeneric objects.')
end end
end end
if(nargin > MIN_NARGIN) if(nargin > MIN_NARGIN)
......
...@@ -962,8 +962,8 @@ class FaustFactory: ...@@ -962,8 +962,8 @@ class FaustFactory:
The Faust object result of the factorization. The Faust object result of the factorization.
Examples: Examples:
>>> from pyfaust import FaustFactory, ParamsPalm4MSA, ConstraintScalar,\ >>> from pyfaust import FaustFactory, ParamsPalm4MSA, ConstraintReal,\
>>> ConstraintName, StoppingCriterion >>> ConstraintInt, ConstraintName, StoppingCriterion
>>> import numpy as np >>> import numpy as np
>>> num_facts = 2 >>> num_facts = 2
>>> is_update_way_R2L = False >>> is_update_way_R2L = False
...@@ -972,8 +972,8 @@ class FaustFactory: ...@@ -972,8 +972,8 @@ class FaustFactory:
>>> init_facts.append(np.zeros([500,32])) >>> init_facts.append(np.zeros([500,32]))
>>> init_facts.append(np.eye(32)) >>> init_facts.append(np.eye(32))
>>> M = np.random.rand(500, 32) >>> M = np.random.rand(500, 32)
>>> cons1 = ConstraintScalar(ConstraintName(ConstraintName.SPLIN), 500, 32, 5) >>> cons1 = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5)
>>> cons2 = ConstraintScalar(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1.0) >>> cons2 = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1.0)
>>> stop_crit = StoppingCriterion(num_its=200) >>> stop_crit = StoppingCriterion(num_its=200)
>>> # default step_size is 1e-16 >>> # default step_size is 1e-16
>>> param = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda, >>> param = ParamsPalm4MSA(num_facts, is_update_way_R2L, init_lambda,
...@@ -1011,8 +1011,8 @@ class FaustFactory: ...@@ -1011,8 +1011,8 @@ class FaustFactory:
The Faust object result of the factorization. The Faust object result of the factorization.
Examples: Examples:
>>> from pyfaust import FaustFactory, ParamsHierarchicalFact, ConstraintScalar,\ >>> from pyfaust import FaustFactory, ParamsHierarchicalFact, ConstraintReal,\
>>> ConstraintName, StoppingCriterion >>> ConstraintInt, ConstraintName, StoppingCriterion
>>> import numpy as np >>> import numpy as np
>>> num_facts = 4 >>> num_facts = 4
>>> is_update_way_R2L = False >>> is_update_way_R2L = False
...@@ -1022,12 +1022,12 @@ class FaustFactory: ...@@ -1022,12 +1022,12 @@ class FaustFactory:
>>> for i in range(1,num_facts): >>> for i in range(1,num_facts):
>>> init_facts.append(np.zeros([32,32])) >>> init_facts.append(np.zeros([32,32]))
>>> M = np.random.rand(500, 32) >>> M = np.random.rand(500, 32)
>>> cons1 = ConstraintScalar(ConstraintName(ConstraintName.SPLIN), 500, 32, 5) >>> cons1 = ConstraintInt(ConstraintName(ConstraintName.SPLIN), 500, 32, 5)
>>> cons2 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, 32, 96) >>> cons2 = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 96)
>>> cons3 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, 32, 96) >>> cons3 = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 96)
>>> cons4 = ConstraintScalar(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1) >>> cons4 = ConstraintReal(ConstraintName(ConstraintName.NORMCOL), 32, 32, 1)
>>> cons5 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, 32, 666) >>> cons5 = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 666)
>>> cons6 = ConstraintScalar(ConstraintName(ConstraintName.SP), 32, 32, 333) >>> cons6 = ConstraintInt(ConstraintName(ConstraintName.SP), 32, 32, 333)
>>> stop_crit1 = StoppingCriterion(num_its=200) >>> stop_crit1 = StoppingCriterion(num_its=200)
>>> stop_crit2 = StoppingCriterion(num_its=200) >>> stop_crit2 = StoppingCriterion(num_its=200)
>>> param = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda, >>> param = ParamsHierarchicalFact(num_facts, is_update_way_R2L, init_lambda,
...@@ -1076,8 +1076,11 @@ class ParamsFact(object): ...@@ -1076,8 +1076,11 @@ class ParamsFact(object):
self.constraints = constraints self.constraints = constraints
self.is_verbose = is_verbose self.is_verbose = is_verbose
self.constant_step_size = constant_step_size self.constant_step_size = constant_step_size
#TODO: raise exception if num_facts != len(init_facts) if(not isinstance(init_facts, list) and not isinstance(init_facts,
#TODO: likewise for constraints tuple) or
len(init_facts) != num_facts):
raise ValueError('ParamsFact init_facts argument must a be '
'list/tuple of '+str(num_facts)+" (num_facts) arguments.")
class ParamsPalm4MSA(ParamsFact): class ParamsPalm4MSA(ParamsFact):
...@@ -1089,8 +1092,11 @@ class ParamsPalm4MSA(ParamsFact): ...@@ -1089,8 +1092,11 @@ class ParamsPalm4MSA(ParamsFact):
constraints, step_size, constraints, step_size,
constant_step_size, constant_step_size,
is_verbose) is_verbose)
if(not isinstance(stop_crit,StoppingCriterion)):
raise TypeError('ParamsPalm4MSA stop_crit argument must be a StoppingCriterion '
'object')
self.stop_crit = stop_crit self.stop_crit = stop_crit
# TODO: raise exception if stop_crit not a StoppingCriterion object #TODO: verify number of constraints is consistent with num_facts
class ParamsHierarchicalFact(ParamsFact): class ParamsHierarchicalFact(ParamsFact):
...@@ -1107,8 +1113,20 @@ class ParamsHierarchicalFact(ParamsFact): ...@@ -1107,8 +1113,20 @@ class ParamsHierarchicalFact(ParamsFact):
self.data_num_cols = data_num_cols self.data_num_cols = data_num_cols
self.stop_crits = stop_crits self.stop_crits = stop_crits
self.is_fact_side_left = is_fact_side_left self.is_fact_side_left = is_fact_side_left
# TODO: raise exception if stop_crits len != 2 or not StoppingCriterion #TODO: verify number of constraints is consistent with num_facts in
# objects if((not isinstance(stop_crits, list) and not isinstance(stop_crits,
tuple)) or
len(stop_crits) != 2 or
not isinstance(stop_crits[0],StoppingCriterion) or not
isinstance(stop_crits[1],StoppingCriterion)):
raise TypeError('ParamsHierarchicalFact stop_crits argument must be a list/tuple of two '
'StoppingCriterion objects')
if((not isinstance(constraints, list) and not isinstance(constraints,
tuple)) or
np.array([not isinstance(constraints[i],ConstraintGeneric) for i in
range(0,len(constraints))]).any()):
raise TypeError('constraints argument must be a list/tuple of '
'ConstraintGeneric (or subclasses) objects')
class StoppingCriterion(object): class StoppingCriterion(object):
...@@ -1134,7 +1152,7 @@ class ConstraintName: ...@@ -1134,7 +1152,7 @@ class ConstraintName:
NORMLIN = 9 # Real Constraint NORMLIN = 9 # Real Constraint
def __init__(self, name): def __init__(self, name):
if(name < ConstraintName.SP or name > ConstraintName.NORMLIN): if(not isinstance(name, np.int) or name < ConstraintName.SP or name > ConstraintName.NORMLIN):
raise ValueError("name must be an integer among ConstraintName.SP," raise ValueError("name must be an integer among ConstraintName.SP,"
"ConstraintName.SPCOL, ConstraintName.NORMCOL," "ConstraintName.SPCOL, ConstraintName.NORMCOL,"
"ConstraintName.SPLINCOL, ConstraintName.CONST," "ConstraintName.SPLINCOL, ConstraintName.CONST,"
...@@ -1155,10 +1173,11 @@ class ConstraintName: ...@@ -1155,10 +1173,11 @@ class ConstraintName:
class ConstraintGeneric(object): class ConstraintGeneric(object):
def __init__(self, name , num_rows, num_cols): def __init__(self, name, num_rows, num_cols, param):
self.value = name self.value = name
self.num_rows = num_rows self.num_rows = num_rows
self.num_cols = num_cols self.num_cols = num_cols
self.param = param
@property @property
def name(self): def name(self):
...@@ -1173,15 +1192,42 @@ class ConstraintGeneric(object): ...@@ -1173,15 +1192,42 @@ class ConstraintGeneric(object):
def is_mat_constraint(self): def is_mat_constraint(self):
return self.value.is_mat_constraint() return self.value.is_mat_constraint()
class ConstraintScalar(ConstraintGeneric): class ConstraintReal(ConstraintGeneric):
def __init__(self, name, num_rows, num_cols, param):
super(ConstraintReal, self).__init__(name, num_rows, num_cols, param)
if(not isinstance(param, np.float) and not isinstance(param, np.int)):
raise TypeError('ConstraintReal must receive a float as param '
'argument.')
self.param = float(self.param)
if(not isinstance(name, ConstraintName) or not name.is_real_constraint()):
raise TypeError('ConstraintReal first argument must be a '
'ConstraintName with a real type name '
'(name.is_real_constraint() must return True).')
class ConstraintInt(ConstraintGeneric):
def __init__(self, name, num_rows, num_cols, param): def __init__(self, name, num_rows, num_cols, param):
super(ConstraintScalar, self).__init__(name, num_rows, num_cols) super(ConstraintInt, self).__init__(name, num_rows, num_cols, param)
self.param = param # raise value error if not np.number (can be complex if(not isinstance(param, np.int)):
# or float/double or integer raise TypeError('CosntraintInt must receive a int as param '
'argument.')
if(not isinstance(name, ConstraintName) or not name.is_int_constraint()):
raise TypeError('ConstraintInt first argument must be a '
'ConstraintName with a int type name '
'(name.is_int_constraint() must return True).')
class ConstraintMat(ConstraintGeneric): class ConstraintMat(ConstraintGeneric):
def __init__(self, name, num_rows, num_cols, param): def __init__(self, name, num_rows, num_cols, param):
super(ConstraintMat, self).__init__(name, num_rows, num_cols) super(ConstraintMat, self).__init__(name, num_rows, num_cols, param)
self.param = param #TODO: raise ValueError if not np.ndarray if(not isinstance(param, np.matrix) and not isinstance(param,
np.ndarray)):
raise TypeError('ConstraintMat must receive a numpy matrix as param '
'argument.')
self.param = float(self.param)
if(not isinstance(name, ConstraintName) or not name.is_mat_constraint()):
raise TypeError('ConstraintMat first argument must be a '
'ConstraintName with a matrix type name'
'(name.is_mat_constraint() must return True).')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment