Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 12f81cc6 authored by hhakim's avatar hhakim
Browse files

Review pyfaust.factparams.ParamsPalm4MSAFGFT._set_init_D and add its unit test.

parent 3bde5773
Branches
Tags
No related merge requests found
...@@ -1581,7 +1581,7 @@ def fgft_palm(U, Lap, p, init_D=None, ret_lambda=False, ret_params=False): ...@@ -1581,7 +1581,7 @@ def fgft_palm(U, Lap, p, init_D=None, ret_lambda=False, ret_params=False):
U: (numpy.ndarray) U: (numpy.ndarray)
The Fourier matrix. The Fourier matrix.
init_D: (numpy.ndarray) init_D: (numpy.ndarray)
The initial diagonal vector. if None it will be the ones() vector by default. The initial diagonal vector. if None it will be vector ones() by default.
p: (ParamsHierarchical) p: (ParamsHierarchical)
The PALM hierarchical algorithm parameters. The PALM hierarchical algorithm parameters.
ret_lambda: (bool) ret_lambda: (bool)
......
...@@ -2085,23 +2085,22 @@ class ParamsPalm4MSAFGFT(ParamsPalm4MSA): ...@@ -2085,23 +2085,22 @@ class ParamsPalm4MSAFGFT(ParamsPalm4MSA):
@staticmethod @staticmethod
def _set_init_D(init_D, dim_sz): def _set_init_D(init_D, dim_sz):
""" """
Utility function for ParamsHierarchicalFGFT, ParamsPalm4MSAFGFT Utility function for ParamsHierarchicalFGFT, ParamsPalm4MSAFGFT.
""" """
def _check_init_D_is_consistent(init_D): def _check_init_D_is_consistent(init_D):
if(not isinstance(init_D, np.ndarray)): if not isinstance(init_D, np.ndarray):
raise ValueError("init_D must be a numpy ndarray") raise TypeError("init_D must be a numpy ndarray")
if(init_D.ndim != 1): if init_D.ndim != 1:
raise ValueError("init_D must be a vector.") raise ValueError("init_D must be a vector.")
if(init_D.shape[0] != dim_sz): if init_D.shape[0] != dim_sz:
raise ValueError("init_D must have the same size as first " raise ValueError("init_D must have the same size as first "
"constraint's number of rows") "constraint number of rows")
if not isinstance(init_D, np.ndarray): if init_D is None:
# default init_D (ones)
init_D = np.ones(dim_sz) init_D = np.ones(dim_sz)
_check_init_D_is_consistent(init_D) _check_init_D_is_consistent(init_D)
return init_D return init_D
else:
return init_D
class ParamsPalm4msaWHT(ParamsPalm4MSA): class ParamsPalm4msaWHT(ParamsPalm4MSA):
""" """
......
import unittest import unittest
from pyfaust.factparams import ConstraintName from pyfaust.factparams import ConstraintName
from pyfaust.factparams import ParamsFact from pyfaust.factparams import ParamsFact, ParamsPalm4MSAFGFT
import numpy as np
class TestFactParams(unittest.TestCase): class TestFactParams(unittest.TestCase):
def __init__(self, methodName='runTest', dev='cpu', dtype='double'): def __init__(self, methodName='runTest', dev='cpu', dtype='double'):
super(TestFactParams, self).__init__(methodName) super(TestFactParams, self).__init__(methodName)
def test_int2str_str2int(self): def test_int2str_str2int(self):
print("ConstraintName.name_int2str & name_str2int") print("ConstraintName.name_int2str & name_str2int")
max_int_name = 0 max_int_name = 0
...@@ -18,8 +16,9 @@ class TestFactParams(unittest.TestCase): ...@@ -18,8 +16,9 @@ class TestFactParams(unittest.TestCase):
if isinstance(ConstraintName.__dict__[attr], int): if isinstance(ConstraintName.__dict__[attr], int):
str_name = attr.lower().replace('_', '').replace('blkdiag', str_name = attr.lower().replace('_', '').replace('blkdiag',
'blockdiag') 'blockdiag')
self.assertEqual(ConstraintName.name_int2str(ConstraintName.__dict__[attr]), self.assertEqual(ConstraintName.name_int2str(
str_name) ConstraintName.__dict__[attr]),
str_name)
self.assertEqual(ConstraintName.__dict__[attr], self.assertEqual(ConstraintName.__dict__[attr],
ConstraintName.str2name_int(str_name)) ConstraintName.str2name_int(str_name))
if ConstraintName.__dict__[attr] > max_int_name: if ConstraintName.__dict__[attr] > max_int_name:
...@@ -31,6 +30,7 @@ class TestFactParams(unittest.TestCase): ...@@ -31,6 +30,7 @@ class TestFactParams(unittest.TestCase):
ConstraintName.name_int2str, max_int_name+1) ConstraintName.name_int2str, max_int_name+1)
def test_factor_format_str2int_int2str(self): def test_factor_format_str2int_int2str(self):
print("ParamsFact.factor_format_int2str/str2int")
formats = ['dense', 'sparse', 'dynamic'] formats = ['dense', 'sparse', 'dynamic']
for i, s in enumerate(formats): for i, s in enumerate(formats):
self.assertEqual(ParamsFact.factor_format_int2str(i), s) self.assertEqual(ParamsFact.factor_format_int2str(i), s)
...@@ -48,8 +48,32 @@ class TestFactParams(unittest.TestCase): ...@@ -48,8 +48,32 @@ class TestFactParams(unittest.TestCase):
self.assertRaisesRegex(ValueError, int_range_error_msg, self.assertRaisesRegex(ValueError, int_range_error_msg,
ParamsFact.factor_format_int2str, 3) ParamsFact.factor_format_int2str, 3)
str_range_error_msg = "factor_format as str must be in " + \ str_range_error_msg = "factor_format as str must be in " + \
repr(formats).replace('[', '\[').replace(']', '\]') repr(formats).replace('[', r'\[').replace(']', r'\]')
self.assertRaisesRegex(ValueError, str_range_error_msg, self.assertRaisesRegex(ValueError, str_range_error_msg,
ParamsFact.factor_format_str2int, 'anyformat') ParamsFact.factor_format_str2int, 'anyformat')
self.assertRaisesRegex(ValueError, str_range_error_msg, self.assertRaisesRegex(ValueError, str_range_error_msg,
ParamsFact.factor_format_int2str, 'anyformat') ParamsFact.factor_format_int2str, 'anyformat')
def test_ParamsPalm4MSAFGFT_set_init_D(self):
print("ParamsPalm4MSAFGFT._set_init_D")
# check errors are detected
not_nparr_err = "init_D must be a numpy ndarray"
not_vec_err = "init_D must be a vector."
invalid_sz_err = "init_D must have the same size as first "
"constraint number of rows"
self.assertRaisesRegex(TypeError, not_nparr_err,
ParamsPalm4MSAFGFT._set_init_D, 'anything', 10)
self.assertRaisesRegex(ValueError, not_vec_err,
ParamsPalm4MSAFGFT._set_init_D, np.ones((10,
10)),
10)
self.assertRaisesRegex(ValueError, invalid_sz_err,
ParamsPalm4MSAFGFT._set_init_D, np.ones((11)),
10)
# check None is accepted and give default vector of ones
self.assertTrue(np.allclose(ParamsPalm4MSAFGFT._set_init_D(None, 10),
np.ones((10))))
# check a proper vector is accepted (i.e. returned)
v = np.random.rand(10)
self.assertTrue(np.allclose(ParamsPalm4MSAFGFT._set_init_D(v, 10),
v))
...@@ -46,9 +46,9 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -46,9 +46,9 @@ cdef class FaustAlgoGen@TYPE_NAME@:
cpp_stop_crit.max_num_its = p.stop_crit.maxiter cpp_stop_crit.max_num_its = p.stop_crit.maxiter
cpp_stop_crit.erreps = p.stop_crit.erreps cpp_stop_crit.erreps = p.stop_crit.erreps
calling_fft_algo = isinstance(init_D, np.ndarray) calling_fgft_algo = isinstance(init_D, np.ndarray)
if(not p.init_facts): if not p.init_facts:
p.init_facts = [ None for i in range(p.num_facts) ] p.init_facts = [ None for i in range(p.num_facts) ]
if(p.is_update_way_R2L): if(p.is_update_way_R2L):
zeros_id = p.num_facts-1 zeros_id = p.num_facts-1
...@@ -62,7 +62,7 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -62,7 +62,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
p.constraints[i]._num_cols, order='F', p.constraints[i]._num_cols, order='F',
dtype=M.dtype) dtype=M.dtype)
if(calling_fft_algo): if calling_fgft_algo:
# FFT/FGFT case, we store lambda in first position and the diagonal # FFT/FGFT case, we store lambda in first position and the diagonal
# of D in the next # of D in the next
_out_buf = np.empty(init_D.shape[0]+1, dtype=M.dtype) _out_buf = np.empty(init_D.shape[0]+1, dtype=M.dtype)
...@@ -70,7 +70,7 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -70,7 +70,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
# store only lambda as a return from Palm4MSA algo # store only lambda as a return from Palm4MSA algo
_out_buf = np.array([0], dtype=M.dtype) _out_buf = np.array([0], dtype=M.dtype)
if(calling_fft_algo): if calling_fgft_algo:
cpp_params = new \ cpp_params = new \
FaustCoreCy.PyxParamsFactPalm4MSAFFT[@TYPE@,double]() FaustCoreCy.PyxParamsFactPalm4MSAFFT[@TYPE@,double]()
init_D_view = init_D init_D_view = init_D
...@@ -144,7 +144,7 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -144,7 +144,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
core = @CORE_CLASS@(core=True) core = @CORE_CLASS@(core=True)
if(calling_fft_algo): if calling_fgft_algo:
core.@CORE_OBJ@ = \ core.@CORE_OBJ@ = \
FaustCoreCy.fact_palm4MSAFFT[@TYPE@,double](&Mview[0,0], FaustCoreCy.fact_palm4MSAFFT[@TYPE@,double](&Mview[0,0],
M_num_rows, M_num_rows,
...@@ -165,7 +165,7 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -165,7 +165,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
del cpp_params del cpp_params
if(calling_fft_algo): if calling_fgft_algo:
return core, np.real(_out_buf[0]), _out_buf[1:] return core, np.real(_out_buf[0]), _out_buf[1:]
else: else:
return core, np.real(_out_buf[0]) return core, np.real(_out_buf[0])
...@@ -217,9 +217,9 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -217,9 +217,9 @@ cdef class FaustAlgoGen@TYPE_NAME@:
cpp_stop_crits[1].erreps = p.stop_crits[1].erreps cpp_stop_crits[1].erreps = p.stop_crits[1].erreps
calling_fft_algo = isinstance(init_D, np.ndarray) calling_fgft_algo = isinstance(init_D, np.ndarray)
if(calling_fft_algo): if calling_fgft_algo:
# FFT/FGFT case, we store lambda in first position and the diagonal # FFT/FGFT case, we store lambda in first position and the diagonal
# of D in the next # of D in the next
_out_buf = np.empty(init_D.shape[0]+1, dtype=M.dtype) _out_buf = np.empty(init_D.shape[0]+1, dtype=M.dtype)
...@@ -227,7 +227,7 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -227,7 +227,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
# store only lambda as a return from Palm4MSA algo # store only lambda as a return from Palm4MSA algo
_out_buf = np.array([0], dtype=M.dtype) _out_buf = np.array([0], dtype=M.dtype)
if(calling_fft_algo): if calling_fgft_algo:
cpp_params = new \ cpp_params = new \
FaustCoreCy.PyxParamsHierarchicalFactFFT[@TYPE@,double]() FaustCoreCy.PyxParamsHierarchicalFactFFT[@TYPE@,double]()
init_D_view = init_D init_D_view = init_D
...@@ -292,7 +292,7 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -292,7 +292,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
cpp_params.num_constraints = len(p.constraints) cpp_params.num_constraints = len(p.constraints)
core = @CORE_CLASS@(core=True) core = @CORE_CLASS@(core=True)
if(calling_fft_algo): if calling_fgft_algo:
core.@CORE_OBJ@ = \ core.@CORE_OBJ@ = \
FaustCoreCy.fact_hierarchical_fft[@TYPE@, FaustCoreCy.fact_hierarchical_fft[@TYPE@,
double](&Mview[0,0], double](&Mview[0,0],
...@@ -314,7 +314,7 @@ cdef class FaustAlgoGen@TYPE_NAME@: ...@@ -314,7 +314,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
if(core.@CORE_OBJ@ == NULL): raise Exception("fact_hierarchical" if(core.@CORE_OBJ@ == NULL): raise Exception("fact_hierarchical"
" has failed."); " has failed.");
if(calling_fft_algo): if calling_fgft_algo:
return core, np.real(_out_buf[0]), _out_buf[1:] return core, np.real(_out_buf[0]), _out_buf[1:]
else: else:
return core, np.real(_out_buf[0]) return core, np.real(_out_buf[0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment