Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 17faaefb authored by hhakim's avatar hhakim
Browse files

Make pyfaust LazyLinearOp * a synonym of @ except if the operand is a scalar...

Make pyfaust LazyLinearOp * a synonym of @ except if the operand is a scalar (following scipy LinearOperator behaviour).
parent cd73786a
Branches
Tags
No related merge requests found
...@@ -1297,7 +1297,17 @@ class LazyLinearOp2(LinearOperator): ...@@ -1297,7 +1297,17 @@ class LazyLinearOp2(LinearOperator):
self._checkattr('__matmul__') self._checkattr('__matmul__')
if not hasattr(op, 'shape'): if not hasattr(op, 'shape'):
raise TypeError('op must have a shape attribute') raise TypeError('op must have a shape attribute')
if swap and op.shape[1] != self.shape[0] or not swap and self.shape[1] != op.shape[0]: if not hasattr(op, 'ndim'):
raise TypeError('op must have a ndim attribute')
if op.ndim == 1 and (self.shape[0] if swap else self.shape[1]) != op.size or op.ndim == 2 and (swap and
op.shape[1]
!=
self.shape[0]
or not
swap and
self.shape[1]
!=
op.shape[0]):
raise ValueError('dimensions must agree') raise ValueError('dimensions must agree')
def __matmul__(self, op): def __matmul__(self, op):
...@@ -1435,15 +1445,9 @@ class LazyLinearOp2(LinearOperator): ...@@ -1435,15 +1445,9 @@ class LazyLinearOp2(LinearOperator):
if np.isscalar(other): if np.isscalar(other):
S = eye(self.shape[1], format='csr') * other S = eye(self.shape[1], format='csr') * other
lop = LazyLinearOp2.create_from_op(S) lop = LazyLinearOp2.create_from_op(S)
elif hasattr(other, 'ndim') and other.ndim == 1 or other.ndim == 2 and other.shape[0] == 1: new_op = self @ lop
if other.size == 1:
return self * self.item()
else:
D = diags(other.squeeze())
lop = LazyLinearOp2.create_from_op(D)
else: else:
raise TypeError('Unsupported type for other') new_op = self @ other
new_op = self @ lop
return new_op return new_op
def __rmul__(self, s): def __rmul__(self, s):
...@@ -1456,7 +1460,10 @@ class LazyLinearOp2(LinearOperator): ...@@ -1456,7 +1460,10 @@ class LazyLinearOp2(LinearOperator):
""" """
# because s is a scalar, it is commutative # because s is a scalar, it is commutative
# for vector broadcasting too # for vector broadcasting too
return self * s if np.isscalar(s):
return self * s
else:
return s @ self
def __imul__(self, op): def __imul__(self, op):
......
...@@ -123,14 +123,15 @@ class TestLazyLinearOpFaust(unittest.TestCase): ...@@ -123,14 +123,15 @@ class TestLazyLinearOpFaust(unittest.TestCase):
def test_mul(self): def test_mul(self):
v = np.random.rand(self.lop.shape[1]) v = np.random.rand(self.lop.shape[1])
lmul2 = self.lop * v lmul2 = self.lop * v
self.assertTrue(isinstance(lmul2, LazyLinearOp2)) self.assertTrue(isinstance(lmul2, np.ndarray))
self.assertAlmostEqual(LA.norm(lmul2.toarray() - (self.lopA * v)), self.assertAlmostEqual(LA.norm(lmul2 - (self.lopA @ v)),
0) 0)
v = np.random.rand(1, self.lop.shape[1]) v = np.random.rand(self.lop.shape[1], 1)
lmul2 = self.lop * v lmul2 = self.lop * v
self.assertTrue(isinstance(lmul2, LazyLinearOp2)) self.assertTrue(isinstance(lmul2, np.ndarray))
self.assertAlmostEqual(LA.norm(lmul2.toarray() - (self.lopA * v)), self.assertAlmostEqual(LA.norm(lmul2 - (self.lopA @ v)),
0) 0)
s = np.random.rand(1, 1)[0, 0] s = np.random.rand(1, 1)[0, 0]
lmul2 = self.lop * s lmul2 = self.lop * s
self.assertTrue(isinstance(lmul2, LazyLinearOp2)) self.assertTrue(isinstance(lmul2, LazyLinearOp2))
...@@ -138,19 +139,20 @@ class TestLazyLinearOpFaust(unittest.TestCase): ...@@ -138,19 +139,20 @@ class TestLazyLinearOpFaust(unittest.TestCase):
0) 0)
def test_rmul(self): def test_rmul(self):
v = np.random.rand(self.lop.shape[1]) v = np.random.rand(self.lop.shape[0])
lmul2 = v * self.lop lmul2 = v * self.lop
self.assertTrue(isinstance(lmul2, LazyLinearOp2)) self.assertTrue(isinstance(lmul2, np.ndarray))
self.assertAlmostEqual(LA.norm(lmul2.toarray() - (v * self.lopA)), self.assertAlmostEqual(LA.norm(lmul2 - (v @ self.lopA)),
0) 0)
v = np.random.rand(1, self.lop.shape[1]) v = np.random.rand(1, self.lop.shape[0])
lmul2 = v * self.lop lmul2 = v @ self.lop
self.assertTrue(isinstance(lmul2, LazyLinearOp2)) self.assertTrue(isinstance(lmul2, np.ndarray))
self.assertAlmostEqual(LA.norm(lmul2.toarray() - (v * self.lopA)), self.assertAlmostEqual(LA.norm(lmul2 - (v @ self.lopA)),
0) 0)
s = np.random.rand(1, 1)[0, 0] s = np.random.rand(1, 1)[0, 0]
self.assertTrue(np.isscalar(s))
lmul2 = s * self.lop lmul2 = s * self.lop
self.assertTrue(isinstance(lmul2, LazyLinearOp2)) self.assertTrue(isinstance(lmul2, LazyLinearOp2))
self.assertAlmostEqual(LA.norm(lmul2.toarray() - (s * self.lopA)), self.assertAlmostEqual(LA.norm(lmul2.toarray() - (s * self.lopA)),
...@@ -193,11 +195,10 @@ class TestLazyLinearOpFaust(unittest.TestCase): ...@@ -193,11 +195,10 @@ class TestLazyLinearOpFaust(unittest.TestCase):
def test_chain_ops(self): def test_chain_ops(self):
lchain = self.lop + self.lop2 lchain = self.lop + self.lop2
lchain = lchain @ self.lop3 lchain = lchain @ self.lop3
lchain = 2 * lchain lchain = 2 * lchain * 3
v = np.random.rand(lchain.shape[1]) self.assertTrue(np.allclose(lchain.toarray(), 6 * (self.lopA + self.lop2A) @ self.lop3A))
lchain = lchain * v
lchain = lchain.concatenate(self.lop3, axis=0) lchain = lchain.concatenate(self.lop3, axis=0)
mat_ref = np.vstack(((2 * (self.lopA + self.lop2A) @ self.lop3A) * v, mat_ref = np.vstack(((2 * (self.lopA + self.lop2A) @ self.lop3A * 3),
self.lop3A)) self.lop3A))
self.assertAlmostEqual(LA.norm(lchain.toarray() - mat_ref), self.assertAlmostEqual(LA.norm(lchain.toarray() - mat_ref),
0) 0)
......
0% — Loading.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment