Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ba74007b authored by hhakim's avatar hhakim
Browse files

Fix dtype issue #298: lazylinop mul functions in complex case when a numpy...

Fix dtype issue #298: lazylinop mul functions in complex case when a numpy array is initialized with np.empty (float64 default dtype).
parent 05093b5d
No related branches found
No related tags found
No related merge requests found
Pipeline #834145 skipped
...@@ -116,6 +116,28 @@ def vstack(tup): ...@@ -116,6 +116,28 @@ def vstack(tup):
else: else:
raise TypeError('lop must be a LazyLinearOp') raise TypeError('lop must be a LazyLinearOp')
def _binary_dtype(A_dtype, B_dtype):
if isinstance(A_dtype, str):
A_dtype = np.dtype(A_dtype)
if isinstance(B_dtype, str):
B_dtype = np.dtype(B_dtype)
if A_dtype is None:
return B_dtype
if B_dtype is None:
return A_dtype
if A_dtype is None and B_dtype is None:
return None
kinds = [A_dtype.kind, B_dtype.kind]
if A_dtype.kind == B_dtype.kind:
dtype = A_dtype if A_dtype.itemsize > B_dtype.itemsize else B_dtype
elif 'c' in [A_dtype.kind, B_dtype.kind]:
dtype = 'complex'
elif 'f' in kinds:
dtype = 'double'
else:
dtype = A_dtype
return dtype
class LazyLinearOp(LinearOperator): class LazyLinearOp(LinearOperator):
""" """
This class implements a lazy linear operator. A LazyLinearOp is a This class implements a lazy linear operator. A LazyLinearOp is a
...@@ -512,7 +534,9 @@ class LazyLinearOp(LinearOperator): ...@@ -512,7 +534,9 @@ class LazyLinearOp(LinearOperator):
elif op.ndim > 2: elif op.ndim > 2:
from itertools import product from itertools import product
# op.ndim > 2 # op.ndim > 2
res = np.empty((*op.shape[:-2], self.shape[0], op.shape[-1])) dtype = _binary_dtype(self.dtype, op.dtype)
res = np.empty((*op.shape[:-2], self.shape[0], op.shape[-1]),
dtype=dtype)
idl = [ list(range(op.shape[i])) for i in range(op.ndim-2) ] idl = [ list(range(op.shape[i])) for i in range(op.ndim-2) ]
for t in product(*idl): for t in product(*idl):
tr = (*t, slice(0, res.shape[-2]), slice(0, res.shape[-1])) tr = (*t, slice(0, res.shape[-2]), slice(0, res.shape[-1]))
...@@ -1026,12 +1050,14 @@ def LazyLinearOperator(shape, **kwargs): ...@@ -1026,12 +1050,14 @@ def LazyLinearOperator(shape, **kwargs):
' passed in kwargs.') ' passed in kwargs.')
def _matmat(M, _matvec): def _matmat(M, _matvec):
nonlocal dtype
if M.ndim == 1: if M.ndim == 1:
return _matvec(M) return _matvec(M)
first_col = _matvec(M[:, 0])
out = np.empty((shape[0], M.shape[1]), dtype=dtype if dtype is not None dtype = first_col.dtype
else M.dtype) out = np.empty((shape[0], M.shape[1]), dtype=dtype)
for i in range(M.shape[1]): out[:, 0] = first_col
for i in range(1, M.shape[1]):
out[:, i] = _matvec(M[:,i]) out[:, i] = _matvec(M[:,i])
return out return out
...@@ -1096,7 +1122,8 @@ def kron(A, B): ...@@ -1096,7 +1122,8 @@ def kron(A, B):
one_dim = True one_dim = True
else: else:
one_dim = False one_dim = False
res = np.empty((shape[0], op.shape[1])) dtype = _binary_dtype(A.dtype, B.dtype)
res = np.empty((shape[0], op.shape[1]), dtype=dtype)
def out_col(j, ncols): def out_col(j, ncols):
for j in range(j, min(j + ncols, op.shape[1])): for j in range(j, min(j + ncols, op.shape[1])):
op_mat = op[:, j].reshape((A.shape[1], B.shape[1])) op_mat = op[:, j].reshape((A.shape[1], B.shape[1]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment