Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 494a35c5 authored by hhakim's avatar hhakim
Browse files

Fix pyfaust.LazyLinearOp.concatenate resulting shape and complete the corresponding unit test.

parent 924edd60
No related branches found
No related tags found
No related merge requests found
Pipeline #834118 skipped
...@@ -423,10 +423,12 @@ class LazyLinearOp: ...@@ -423,10 +423,12 @@ class LazyLinearOp:
axis: axis of concatenation (0 for rows, 1 for columns). axis: axis of concatenation (0 for rows, 1 for columns).
""" """
from pyfaust import concatenate as cat from pyfaust import concatenate as cat
new_shape = (self.shape[0] + op.shape[0] if axis == 0 else self.shape[0],
self.shape[1] + op.shape[1] if axis == 1 else self.shape[1])
new_op = self.__class__(init_lambda=lambda: new_op = self.__class__(init_lambda=lambda:
cat((self._lambda_stack(), cat((self._lambda_stack(),
LazyLinearOp._eval_if_lazy(op)), axis=axis), LazyLinearOp._eval_if_lazy(op)), axis=axis),
shape=(tuple(self.shape)), shape=(new_shape),
root_obj=self._root_obj) root_obj=self._root_obj)
return new_op return new_op
......
...@@ -140,10 +140,19 @@ class TestLazyLinearOpFaust(unittest.TestCase): ...@@ -140,10 +140,19 @@ class TestLazyLinearOpFaust(unittest.TestCase):
self.assertAlmostEqual(LA.norm(lcat.toarray() - np.vstack((self.lopA, self.assertAlmostEqual(LA.norm(lcat.toarray() - np.vstack((self.lopA,
self.lop2A))), self.lop2A))),
0) 0)
self.assertEqual(lcat.shape[0], self.lop.shape[0] + self.lop2.shape[0])
lcat = self.lop.concatenate(self.lop2, axis=1) lcat = self.lop.concatenate(self.lop2, axis=1)
self.assertAlmostEqual(LA.norm(lcat.toarray() - np.hstack((self.lopA, self.assertAlmostEqual(LA.norm(lcat.toarray() - np.hstack((self.lopA,
self.lop2A))), self.lop2A))),
0) 0)
self.assertEqual(lcat.shape[1], self.lop.shape[1] + self.lop2.shape[1])
# auto concat
lcat = self.lop.concatenate(self.lop, axis=0)
self.assertAlmostEqual(LA.norm(lcat.toarray() - np.vstack((self.lopA,
self.lopA))),
0)
self.assertEqual(lcat.shape[0], self.lop.shape[0] + self.lop.shape[0])
def test_chain_ops(self): def test_chain_ops(self):
lchain = self.lop + self.lop2 lchain = self.lop + self.lop2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment