Mentions légales du service

Skip to content
Snippets Groups Projects
Commit cf26c3fc authored by hhakim's avatar hhakim
Browse files

Update pyfaust.lazylinop.diag.

- Handle x.ndim == 2 in multiplication.
- Handle x is a LazyLinearOp.
parent 8af7dc12
Branches
Tags
No related merge requests found
...@@ -1136,18 +1136,37 @@ def eye(m, n=None, k = 0, dtype='float'): ...@@ -1136,18 +1136,37 @@ def eye(m, n=None, k = 0, dtype='float'):
def diag(v, k=0): def diag(v, k=0):
""" """
Construct a diagonal LazyLinearOp.
Args:
v: a 1-D numpy array for the diagonal.
k: (int) the index of digonal, 0 for the main diagonal, k>0 for diagonals
above, k<0 for diagonals below.
Returns:
The diagonal LazyLinearOperator.
""" """
if v.ndim > 1 or v.ndim == 0: if v.ndim > 1 or v.ndim == 0:
raise ValueError("v must be a 1-dim vector.") raise ValueError("v must be a 1-dim vector.")
m = v.size + abs(k) m = v.size + abs(k)
def matmat(x, v, k): def matmat(x, v, k):
if k > 0: v = v.reshape(v.size, 1)
y = v * x[k:k+v.size] if x.ndim == 1:
y = np.hstack((y, np.zeros(k))) x_is_1d = True
elif k < 0: x = x.reshape(x.size, 1)
y = v * x[:v.size] else:
y = np.hstack((np.zeros(abs(k)), y)) x_is_1d = False
# TODO: if x is LazyLinearOp, do a naive np.diag(v) @ x if isLazyLinearOp(x):
y = np.diag(v, k) @ x
else:
if k > 0:
y = v * x[k:k+v.size]
y = np.vstack((y, np.zeros((k, x.shape[1]))))
elif k < 0:
y = v * x[:v.size]
y = np.vstack((np.zeros((abs(k), x.shape[1])), y))
if x_is_1d:
y = y.ravel()
return y return y
return LazyLinearOperator((m, m), matmat=lambda x: matmat(x, v, k), return LazyLinearOperator((m, m), matmat=lambda x: matmat(x, v, k),
rmatmat=lambda x: matmat(x, v, -k)) rmatmat=lambda x: matmat(x, v, -k))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment