Commit f6c005cd authored by hhakim
Add pyfaust.lazylinop.kron and the associated LazyLinearOpKron class.

parent a4665048
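A quick usage sketch of the API this commit introduces (hedged: the `pyfaust.lazylinop` import path is taken from the commit message; the released package may expose these names differently). It builds a lazy Kronecker operator and multiplies it without ever materializing `np.kron(A, B)`:

    import numpy as np
    from pyfaust.lazylinop import kron   # path assumed from the commit message

    A = np.random.rand(3, 4)
    B = np.random.rand(5, 6)
    K = kron(A, B)                       # lazy: the Kronecker matrix is not formed
    x = np.random.rand(K.shape[1])
    y = K @ x                            # computed via the vec-trick (see the check
                                         # after the new class at the end of the diff)
    assert np.allclose(y, np.kron(A, B) @ x)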
@@ -120,8 +120,8 @@ class LazyLinearOp(LinearOperator):
         Returns the LazyLinearOp transpose.
         """
         self._checkattr('transpose')
-        new_op = self.__class__(init_lambda=lambda:
-                                (self._lambda_stack()).transpose(),
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              (self.eval()).transpose(),
                               shape=(self.shape[1], self.shape[0]),
                               root_obj=self._root_obj)
         return new_op
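Every hunk that follows repeats the same two substitutions: `self.__class__` becomes an explicit `LazyLinearOp` constructor call (presumably so that subclasses such as the new `LazyLinearOpKron`, whose constructor takes different arguments, don't break these generic wrappers), and the deferred computation is expressed through `self.eval()` instead of `self._lambda_stack()`. A minimal, self-contained sketch of the deferred-evaluation pattern (an illustrative stand-in, not the actual class):

    import numpy as np

    class Lazy:
        # Stand-in for LazyLinearOp: store a thunk, run it only on eval().
        def __init__(self, init_lambda, shape):
            self._lambda = init_lambda
            self.shape = shape

        def eval(self):
            return self._lambda()                    # work happens here, on demand

        def transpose(self):
            # Same bookkeeping as the hunk above: defer .transpose() to eval time.
            return Lazy(lambda: self.eval().transpose(),
                        (self.shape[1], self.shape[0]))

    L = Lazy(lambda: np.arange(6).reshape(2, 3), (2, 3))
    Lt = L.transpose()                               # nothing computed yet
    assert Lt.eval().shape == (3, 2)                 # evaluated only when asked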
@@ -138,8 +138,8 @@ class LazyLinearOp(LinearOperator):
         Returns the LazyLinearOp conjugate.
         """
         self._checkattr('conj')
-        new_op = self.__class__(init_lambda=lambda:
-                                (self._lambda_stack()).conj(),
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              (self.eval()).conj(),
                               shape=self.shape,
                               root_obj=self._root_obj)
         return new_op
@@ -155,8 +155,8 @@ class LazyLinearOp(LinearOperator):
         Returns the LazyLinearOp adjoint/transconjugate.
         """
         self._checkattr('getH')
-        new_op = self.__class__(init_lambda=lambda:
-                                (self._lambda_stack()).getH(),
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              (self.eval()).getH(),
                               shape=(self.shape[1], self.shape[0]),
                               root_obj=self._root_obj)
         return new_op
@@ -183,8 +183,8 @@ class LazyLinearOp(LinearOperator):
         """
         self._checkattr('__add__')
-        new_op = self.__class__(init_lambda=lambda:
-                                self._lambda_stack() + LazyLinearOp._eval_if_lazy(op),
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              self.eval() + LazyLinearOp._eval_if_lazy(op),
                               shape=(tuple(self.shape)),
                               root_obj=self._root_obj)
         return new_op
@@ -203,11 +203,11 @@ class LazyLinearOp(LinearOperator):
         """
         Not Implemented self += op.
         """
-        raise NotImplementedError(self.__class__.__name__+".__iadd__")
+        raise NotImplementedError(LazyLinearOp.__name__+".__iadd__")
         # can't do as follows, it recurses indefinitely because of self.eval
         # self._checkattr('__iadd__')
-        # self = self.__class__(init_lambda=lambda:
-        #                       (self._lambda_stack()).\
+        # self = LazyLinearOp(init_lambda=lambda:
+        #                     (self.eval()).\
         #                       __iadd__(LazyLinearOp._eval_if_lazy(op)),
         #                       shape=(tuple(self.shape)),
         #                       root_obj=self._root_obj)
@@ -223,7 +223,7 @@ class LazyLinearOp(LinearOperator):
         """
         self._checkattr('__sub__')
-        new_op = self.__class__(init_lambda=lambda: self._lambda_stack() - LazyLinearOp._eval_if_lazy(op),
+        new_op = LazyLinearOp(init_lambda=lambda: self.eval() - LazyLinearOp._eval_if_lazy(op),
                               shape=(tuple(self.shape)),
                               root_obj=self._root_obj)
         return new_op
@@ -237,9 +237,9 @@ class LazyLinearOp(LinearOperator):
         """
         self._checkattr('__rsub__')
-        new_op = self.__class__(init_lambda=lambda:
+        new_op = LazyLinearOp(init_lambda=lambda:
                               LazyLinearOp._eval_if_lazy(op) -
-                              self._lambda_stack(),
+                              self.eval(),
                               shape=(tuple(self.shape)),
                               root_obj=self._root_obj)
         return new_op
@@ -248,11 +248,11 @@ class LazyLinearOp(LinearOperator):
         """
         Not implemented self -= op.
         """
-        raise NotImplementedError(self.__class__.__name__+".__isub__")
+        raise NotImplementedError(LazyLinearOp.__name__+".__isub__")
         # can't do as follows, it recurses indefinitely because of self.eval
         # self._checkattr('__isub__')
-        # self = self.__class__(init_lambda=lambda:
-        #                       (self._lambda_stack()).\
+        # self = LazyLinearOp(init_lambda=lambda:
+        #                     (self.eval()).\
         #                       __isub__(LazyLinearOp._eval_if_lazy(op)),
        #                        shape=(tuple(self.shape)),
        #                        root_obj=self._root_obj)
@@ -268,8 +268,8 @@ class LazyLinearOp(LinearOperator):
         """
         self._checkattr('__truediv__')
-        new_op = self.__class__(init_lambda=lambda:
-                                self._lambda_stack() / LazyLinearOp._eval_if_lazy(op),
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              self.eval() / LazyLinearOp._eval_if_lazy(op),
                               shape=(tuple(self.shape)),
                               root_obj=self._root_obj)
         return new_op
@@ -278,12 +278,12 @@ class LazyLinearOp(LinearOperator):
         """
         Not implemented self /= op.
         """
-        raise NotImplementedError(self.__class__.__name__+".__itruediv__")
+        raise NotImplementedError(LazyLinearOp.__name__+".__itruediv__")
         # can't do as follows, it recurses indefinitely because of self.eval
         #
         # self._checkattr('__itruediv__')
-        # self = self.__class__(init_lambda=lambda:
-        #                       (self._lambda_stack()).\
+        # self = LazyLinearOp(init_lambda=lambda:
+        #                     (self.eval()).\
         #                       __itruediv__(LazyLinearOp._eval_if_lazy(op)),
         #                       shape=(tuple(self.shape)),
         #                       root_obj=self._root_obj)
@@ -302,13 +302,13 @@ class LazyLinearOp(LinearOperator):
             raise TypeError('op must have a shape attribute')
         if self.shape[1] != op.shape[0]:
             raise ValueError('dimensions must agree')
-        if isinstance(op, LazyLinearOp):
-            res = self.__class__(init_lambda=lambda:
-                                 self.eval() @ op.eval(),
+        if isinstance(op, np.ndarray):
+            res = self.eval() @ op
+        else:
+            res = LazyLinearOp(init_lambda=lambda:
+                               self.eval() @ LazyLinearOp._eval_if_lazy(op),
                                shape=(self.shape[0], op.shape[1]),
                                root_obj=self._root_obj)
-        else:
-            res = self.eval() @ op
         return res

     def dot(self, op):
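This hunk also flips the `@` dispatch: a `numpy.ndarray` operand is now evaluated eagerly (returning a dense result), while any other operand, lazy or not, yields a new `LazyLinearOp`. A small sketch under the constructor signature visible in this diff (`init_lambda`, `shape`, `root_obj`; it may not be part of the public API):

    import numpy as np

    A = np.arange(12.).reshape(3, 4)
    L = LazyLinearOp(init_lambda=lambda: A, shape=A.shape, root_obj=A)
    y = L @ np.ones((4, 2))      # ndarray operand: computed immediately
    M = L @ L.transpose()        # lazy operand: still a LazyLinearOp
    assert isinstance(y, np.ndarray)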
@@ -367,11 +367,11 @@ class LazyLinearOp(LinearOperator):
         """
         Not implemented self @= op.
         """
-        raise NotImplementedError(self.__class__.__name__+".__imatmul__")
+        raise NotImplementedError(LazyLinearOp.__name__+".__imatmul__")
         # can't do as follows, it recurses indefinitely because of self.eval
         # self._checkattr('__imatmul__')
-        # self = self.__class__(init_lambda=lambda:
-        #                       (self._lambda_stack()).\
+        # self = LazyLinearOp(init_lambda=lambda:
+        #                     (self.eval()).\
         #                       __imatmul__(LazyLinearOp._eval_if_lazy(op)),
         #                       shape=(tuple(self.shape)),
         #                       root_obj=self._root_obj)
@@ -391,7 +391,7 @@ class LazyLinearOp(LinearOperator):
         if self.shape[0] != op.shape[1]:
             raise ValueError('dimensions must agree')
         if isinstance(op, LazyLinearOp):
-            res = self.__class__(init_lambda=lambda:
+            res = LazyLinearOp(init_lambda=lambda:
                                op.eval() @ self.eval(),
                                shape=(self.shape[0], op.shape[1]),
                                root_obj=self._root_obj)
@@ -414,8 +414,8 @@ class LazyLinearOp(LinearOperator):
            self.shape == op.shape or \
            op.shape[0] == 1 and op.shape[1] == self.shape[1] or \
            op.shape[1] == 1 and op.shape[0] == self.shape[0]:
-            new_op = self.__class__(init_lambda=lambda:
-                                    self._lambda_stack() * LazyLinearOp._eval_if_lazy(op),
+            new_op = LazyLinearOp(init_lambda=lambda:
+                                  self.eval() * LazyLinearOp._eval_if_lazy(op),
                                   shape=(tuple(self.shape)),
                                   root_obj=self._root_obj)
         return new_op
@@ -436,9 +436,9 @@ class LazyLinearOp(LinearOperator):
            op.shape[0] == 1 and op.shape[1] == self.shape[1] or \
            op.shape[1] == 1 and op.shape[0] == self.shape[0]:
             self._checkattr('__rmul__')
-            new_op = self.__class__(init_lambda=lambda:
+            new_op = LazyLinearOp(init_lambda=lambda:
                                   LazyLinearOp._eval_if_lazy(op) *
-                                  self._lambda_stack(),
+                                  self.eval(),
                                   shape=(tuple(self.shape)),
                                   root_obj=self._root_obj)
         return new_op
@@ -449,11 +449,11 @@ class LazyLinearOp(LinearOperator):
         """
         Not implemented self *= op.
         """
-        raise NotImplementedError(self.__class__.__name__+".__imul__")
+        raise NotImplementedError(LazyLinearOp.__name__+".__imul__")
         # # can't do as follows, it recurses indefinitely because of self.eval
         # self._checkattr('__imul__')
-        # self = self.__class__(init_lambda=lambda:
-        #                       (self._lambda_stack()).\
+        # self = LazyLinearOp(init_lambda=lambda:
+        #                     (self.eval()).\
         #                       __imul__(LazyLinearOp._eval_if_lazy(op)),
         #                       shape=(tuple(self.shape)),
         #                       root_obj=self._root_obj)
@@ -483,8 +483,8 @@ class LazyLinearOp(LinearOperator):
         if isinstance(indices, tuple) and len(indices) == 2 and isinstance(indices[0], int) and isinstance(indices[1], int):
             return self.eval().__getitem__(indices)
         else:
-            return self.__class__(init_lambda=lambda:
-                                  (self._lambda_stack()).\
+            return LazyLinearOp(init_lambda=lambda:
+                                (self.eval()).\
                                 __getitem__(indices),
                                 shape=self._newshape_getitem(indices),
                                 root_obj=self._root_obj)
@@ -582,8 +582,8 @@ class LazyLinearOp(LinearOperator):
         for op in ops:
             ncols += op.shape[1]
         new_shape = (nrows, ncols)
-        new_op = self.__class__(init_lambda=lambda:
-                                cat((self._lambda_stack(),
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              cat((self.eval(),
                                    *[LazyLinearOp._eval_if_lazy(op) for op in
                                      ops]), axis=axis),
                               shape=(new_shape),
@@ -597,8 +597,8 @@ class LazyLinearOp(LinearOperator):
         Returns the LazyLinearOp for real.
         """
         self._checkattr('real')
-        new_op = self.__class__(init_lambda=lambda:
-                                (self._lambda_stack()).real,
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              (self.eval()).real,
                               shape=self.shape,
                               root_obj=self._root_obj)
         return new_op
@@ -609,8 +609,8 @@ class LazyLinearOp(LinearOperator):
         Returns the LazyLinearOp for imag.
         """
         self._checkattr('imag')
-        new_op = self.__class__(init_lambda=lambda:
-                                (self._lambda_stack()).imag,
+        new_op = LazyLinearOp(init_lambda=lambda:
+                              (self.eval()).imag,
                               shape=self.shape,
                               root_obj=self._root_obj)
         return new_op
@@ -735,3 +735,60 @@ def vstack(tup):
         return lop.concatenate(*tup[1:], axis=0)
     else:
         raise TypeError('lop must be a LazyLinearOp')
+
+
+def kron(A, B):
+    return LazyLinearOpKron(lambda: A, A, B)
+
+
+class LazyLinearOpKron(LazyLinearOp):
+
+    def __init__(self, init_lambda, A, B):
+        self.A = A
+        self.B = B
+        shape = (A.shape[0] * B.shape[0], A.shape[1] * B.shape[1])
+        super(LazyLinearOpKron, self).__init__(init_lambda, shape, A)
+
+    def conj(self):
+        return LazyLinearOpKron(self._lambda_stack, self.A.conj(), self.B.conj())
+
+    def transpose(self):
+        return LazyLinearOpKron(self._lambda_stack, self.A.T, self.B.T)
+
+    def getH(self):
+        return LazyLinearOpKron(self._lambda_stack, self.A.getH(), self.B.getH())
+
+    def __matmul__(self, op):
+        # TODO: refactor with parent function
+        self._checkattr('__matmul__')
+        if not hasattr(op, 'shape'):
+            raise TypeError('op must have a shape attribute')
+        if self.shape[1] != op.shape[0]:
+            raise ValueError('dimensions must agree')
+        if hasattr(op, 'reshape') and hasattr(op, '__matmul__') and hasattr(op,
+                                                                            '__getitem__'):
+            if op.ndim == 1:
+                op = op.reshape((op.size, 1))
+                one_dim = True
+            else:
+                one_dim = False
+            res = np.empty((self.shape[0], op.shape[1]))
+            for j in range(op.shape[1]):
+                op_mat = op[:, j].reshape((self.A.shape[1], self.B.shape[1]))
+                res[:, j] = (LazyLinearOp._eval_if_lazy(self.A) @ op_mat @
+                             LazyLinearOp._eval_if_lazy(self.B).T).reshape(self.shape[0])
+            if one_dim:
+                res = res.ravel()
+        else:
+            res = LazyLinearOp(init_lambda=lambda:
+                               self.eval() @ LazyLinearOp._eval_if_lazy(op),
+                               shape=(self.shape[0], op.shape[1]),
+                               root_obj=self._root_obj)
+        return res
+
+    def eval(self):
+        A = self.A
+        B = self.B
+        if not isinstance(A, np.ndarray):
+            A = A.toarray()
+        if not isinstance(B, np.ndarray):
+            B = B.toarray()
+        return np.kron(A, B)
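The column loop in `LazyLinearOpKron.__matmul__` exploits the Kronecker vec identity for row-major (C-order) reshapes: `(A ⊗ B) @ x == (A @ X @ B.T).ravel()` with `X = x.reshape(A.shape[1], B.shape[1])`, so only `eval()` ever forms the full Kronecker matrix. A quick NumPy-only check of the identity the loop relies on:

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.random((3, 4))
    B = rng.random((5, 6))
    x = rng.random(A.shape[1] * B.shape[1])

    X = x.reshape(A.shape[1], B.shape[1])   # C-order reshape, as in the loop above
    lhs = np.kron(A, B) @ x                 # dense reference computation
    rhs = (A @ X @ B.T).ravel()             # what __matmul__ computes per column
    assert np.allclose(lhs, rhs)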