Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 3544a9c5 authored by hhakim's avatar hhakim
Browse files

Test also the recursive method of pyfaust.concatenate in testConcatenate unit test.

parent 68ba738d
No related branches found
No related tags found
No related merge requests found
Pipeline #873266 passed
......@@ -781,36 +781,38 @@ class TestFaustPy(unittest.TestCase):
self.assertEqual(C.shape, ref_C.shape)
self.assertLessEqual(norm(C.toarray()-ref_C)/norm(ref_C),
10**-5)
# test random number of Fausts concatenation
# test concatenation of a random number of Fausts
from pyfaust import rand as frand, concatenate as cat, isFaust
for cat_axis in [0, 1]:
n = self.r.randint(3, 18)
fausts = []
field_names = ['real', 'complex']
fac_types = ['sparse', 'dense', 'mixed']
for i in range(n):
fac_type_id = self.r.randint(0,2)
field_id = self.r.randint(0,1)
is_faust = bool(self.r.randint(0,1)) or i == 0
nrows = self.r.randint(2, 128) if cat_axis == 0 else F.shape[0]
ncols = self.r.randint(2, 128) if cat_axis == 1 else F.shape[1]
if is_faust:
fausts += [frand(nrows, ncols,
fac_type=fac_types[fac_type_id],
field=field_names[field_id])]
else:
fausts += [frand(nrows, ncols, fac_type=fac_types[fac_type_id],
field=field_names[field_id],
num_factors=1).factors(0)]
Fc = cat(tuple(fausts), axis=cat_axis)
arrays = []
for F in fausts:
if not isinstance(F, np.ndarray):
arrays += [F.toarray()]
else:
arrays += [F]
Mc = np.concatenate(arrays, axis=cat_axis)
self.assertTrue(np.allclose(Fc.toarray(), Mc))
for iterative in [True, False]:
# iterative == False means recursive
for cat_axis in [0, 1]:
n = self.r.randint(3, 18)
fausts = []
field_names = ['real', 'complex']
fac_types = ['sparse', 'dense', 'mixed']
for i in range(n):
fac_type_id = self.r.randint(0,2)
field_id = self.r.randint(0,1)
is_faust = bool(self.r.randint(0,1)) or i == 0
nrows = self.r.randint(2, 128) if cat_axis == 0 else F.shape[0]
ncols = self.r.randint(2, 128) if cat_axis == 1 else F.shape[1]
if is_faust:
fausts += [frand(nrows, ncols,
fac_type=fac_types[fac_type_id],
field=field_names[field_id])]
else:
fausts += [frand(nrows, ncols, fac_type=fac_types[fac_type_id],
field=field_names[field_id],
num_factors=1).factors(0)]
Fc = cat(tuple(fausts), axis=cat_axis, iterative=iterative)
arrays = []
for F in fausts:
if not isinstance(F, np.ndarray):
arrays += [F.toarray()]
else:
arrays += [F]
Mc = np.concatenate(arrays, axis=cat_axis)
self.assertTrue(np.allclose(Fc.toarray(), Mc))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment