Mentions légales du service

Skip to content
Snippets Groups Projects
Commit c2a291f8 authored by hhakim's avatar hhakim
Browse files

Add __repr__ functions to StoppingCriterion, ParamsHierarchical/Palm4MSA,...

Add __repr__ functions to StoppingCriterion, ParamsHierarchical/Palm4MSA, ConstraintGen and function ConstraintName.name_int2str.
parent d105bb5f
Branches
No related tags found
No related merge requests found
...@@ -116,6 +116,12 @@ class ConstraintGeneric(ABC): ...@@ -116,6 +116,12 @@ class ConstraintGeneric(ABC):
if(M.shape[0] != self._num_rows or M.shape[1] != self._num_cols): if(M.shape[0] != self._num_rows or M.shape[1] != self._num_cols):
raise ValueError("The dimensions must agree.") raise ValueError("The dimensions must agree.")
def __repr__(self):
return self._name.name_str()+"("+str(self._num_rows)+","+ \
str(self._num_cols) + (", "+str(self._cons_value) + ")" if not
self.is_mat_constraint()
else ")")
class ConstraintInt(ConstraintGeneric): class ConstraintInt(ConstraintGeneric):
""" """
This class represents an integer constraint on a matrix. This class represents an integer constraint on a matrix.
...@@ -410,6 +416,49 @@ class ConstraintName: ...@@ -410,6 +416,49 @@ class ConstraintName:
ConstraintName.CIRC, ConstraintName.TOEPLITZ, ConstraintName.CIRC, ConstraintName.TOEPLITZ,
ConstraintName.HANKEL, ConstraintName.BLKDIAG] ConstraintName.HANKEL, ConstraintName.BLKDIAG]
def name_str(self):
return ConstraintName.name_int2str(self.name)
@staticmethod
def name_int2str(_id):
"""
Converts a int constraint short name to its str constant name equivalent.
For example, name_int2str(ConstraintName.SP) returns 'sp'.
"""
err_msg = "Invalid argument to designate a ConstraintName."
if(not isinstance(_id, int)):
raise ValueError(err_msg)
if(_id == ConstraintName.SP):
_str = 'sp'
elif(_id == ConstraintName.SPLIN):
_str = 'splin'
elif(_id == ConstraintName.SPCOL):
_str = 'spcol'
elif(_id == ConstraintName.SPLINCOL):
_str = 'splincol'
elif(_id == ConstraintName.SP_POS):
_str = 'sppos'
elif(_id == ConstraintName.NORMCOL):
_str = 'normcol'
elif(_id == ConstraintName.NORMLIN):
_str = 'normlin'
elif(_id == ConstraintName.SUPP):
_str = 'supp'
elif(_id == ConstraintName.CONST):
_str = 'const'
elif(_id == ConstraintName.CIRC):
_str = 'circ'
elif(_id == ConstraintName.TOEPLITZ):
_str = 'toeplitz'
elif(_id == ConstraintName.HANKEL):
_str = 'hankel'
elif(_id == ConstraintName.BLKDIAG):
_str = 'blockdiag'
else:
raise ValueError(err_msg)
return _str
@staticmethod @staticmethod
def str2name_int(_str): def str2name_int(_str):
""" """
...@@ -578,6 +627,23 @@ class ParamsFact(ABC): ...@@ -578,6 +627,23 @@ class ParamsFact(ABC):
self.use_csr = use_csr self.use_csr = use_csr
self.packing_RL = packing_RL self.packing_RL = packing_RL
def __repr__(self):
"""
Returns object representation.
"""
return ("num_facts="+str( self.num_facts)+'\r\n'
"is_update_way_R2L="+str( self.is_update_way_R2L)+'\r\n'
"init_lambda="+str( self.init_lambda)+'\r\n'
"step_size="+str( self.step_size)+'\r\n'
"constant_step_size="+str( self.constant_step_size)+'\r\n'
"grad_calc_opt_mode="+str( self.grad_calc_opt_mode)+'\r\n'
"norm2_max_iter="+str( self.norm2_max_iter)+'\r\n'
"norm2_threshold="+str( self.norm2_threshold)+'\r\n'
"use_csr="+str( self.use_csr)+'\r\n'
"packing_RL="+str( self.packing_RL)+'\r\n'
"is_verbose="+str( self.is_verbose)+'\r\n'
"constraints="+str( self.constraints))+'\r\n'
@abstractmethod @abstractmethod
def is_mat_consistent(self, M): def is_mat_consistent(self, M):
if(not isinstance(M, np.ndarray)): if(not isinstance(M, np.ndarray)):
...@@ -722,7 +788,15 @@ class ParamsHierarchical(ParamsFact): ...@@ -722,7 +788,15 @@ class ParamsHierarchical(ParamsFact):
if(not isinstance(M, np.ndarray)): if(not isinstance(M, np.ndarray)):
raise ValueError("M must be a numpy ndarray") raise ValueError("M must be a numpy ndarray")
return M.shape[0] == self.data_num_rows and \ return M.shape[0] == self.data_num_rows and \
M.shape[1] == self.data_num_cols M.shape[0] == self.data_num_cols
def __repr__(self):
"""
Returns object representation.
"""
return super(ParamsHierarchical, self).__repr__()+ \
"local stopping criterion: "+str(self.stop_crits[0])+"\r\n" \
"global stopping criterion"+str(self.stop_crits[1])
class ParamsHierarchicalSquareMat(ParamsHierarchical): class ParamsHierarchicalSquareMat(ParamsHierarchical):
""" """
...@@ -928,6 +1002,10 @@ class ParamsPalm4MSA(ParamsFact): ...@@ -928,6 +1002,10 @@ class ParamsPalm4MSA(ParamsFact):
def is_mat_consistent(self, M): def is_mat_consistent(self, M):
return super(ParamsPalm4MSA, self).is_mat_consistent(M) return super(ParamsPalm4MSA, self).is_mat_consistent(M)
def __repr__(self):
return super(ParamsPalm4MSA, self).__repr__()+ \
"stopping criterion: "+str(self.stop_crit)
class ParamsPalm4MSAFGFT(ParamsPalm4MSA): class ParamsPalm4MSAFGFT(ParamsPalm4MSA):
""" """
""" """
...@@ -1053,6 +1131,12 @@ class StoppingCriterion(object): ...@@ -1053,6 +1131,12 @@ class StoppingCriterion(object):
return "num_its: "+str(self.num_its)+ \ return "num_its: "+str(self.num_its)+ \
", maxiter: " + str(self.maxiter) ", maxiter: " + str(self.maxiter)
def __repr__(self):
"""
Returns the StoppingCriterion object representation.
"""
return self.__str__()
class ParamsFactFactory: class ParamsFactFactory:
""" """
The factory for creating simplified FAuST hierarchical algorithm parameters (ParamsHierarchical). The factory for creating simplified FAuST hierarchical algorithm parameters (ParamsHierarchical).
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment