Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 8e507563 authored by hhakim's avatar hhakim
Browse files

Add a unit test for pyfaust.Faust.concatenate().

parent 67c6974a
Branches
No related tags found
No related merge requests found
...@@ -278,6 +278,37 @@ class TestFaustPy(unittest.TestCase): ...@@ -278,6 +278,37 @@ class TestFaustPy(unittest.TestCase):
#TODO: test mul by a complex scalar when impl. #TODO: test mul by a complex scalar when impl.
def testConcatenate(self):
from pyfaust import FaustFactory
F = self.F
#if(F.dtype == np.complex): return
for cat_axis in [0,1]:
G = \
FaustFactory.rand(self.r.randint(1,TestFaustPy.MAX_NUM_FACTORS),
F.shape[(cat_axis+1)%2],
is_real=not isinstance(self,TestFaustPyCplx))
# add one random factor to get a random number of rows to test
# vertcat and a random number of cols to test horzcat
if cat_axis == 0:
M = sparse.csr_matrix(np.random.rand(
self.r.randint(1,TestFaustPy.MAX_DIM_SIZE),
F.shape[(cat_axis+1)%2]).astype(F.dtype))
H = Faust([M]+[G.get_factor(i) for i in
range(0,G.get_num_factors())])
else:
M = sparse.csr_matrix(np.random.rand(
F.shape[(cat_axis+1)%2],
self.r.randint(1,TestFaustPy.MAX_DIM_SIZE)).astype(F.dtype))
H = Faust([G.get_factor(i) for i in
range(0,G.get_num_factors())]+[M])
print("testConcatenate() F.shape, H.shape", F.shape, H.shape)
C = F.concatenate(H,axis=cat_axis)
ref_C = np.concatenate((F.toarray(),
H.toarray()),
axis=cat_axis)
self.assertLessEqual(np.linalg.norm(C.toarray()-ref_C)/norm(ref_C),
10**-5)
def testTranspose(self): def testTranspose(self):
print("testTranspose()") print("testTranspose()")
tFaust = self.F.transpose() tFaust = self.F.transpose()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment