Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 049e884b authored by hhakim's avatar hhakim
Browse files

Add pyfaust LazyLinearOp support of matmul by a multidimensional numpy array (nd greater than 2).

parent eda6ac6a
No related branches found
No related tags found
No related merge requests found
......@@ -444,15 +444,16 @@ class LazyLinearOp(LinearOperator):
raise TypeError('op must have a shape attribute')
if not hasattr(op, 'ndim'):
raise TypeError('op must have a ndim attribute')
if op.ndim == 1 and (self.shape[0] if swap else self.shape[1]) != op.size or op.ndim == 2 and (swap and
op.shape[1]
!=
self.shape[0]
or not
swap and
self.shape[1]
!=
op.shape[0]):
if op.ndim == 1 and (self.shape[0] if swap else self.shape[1]) != \
op.size or op.ndim >= 2 and (swap and
op.shape[-1]
!=
self.shape[0]
or not
swap and
self.shape[1]
!=
op.shape[-2]):
raise ValueError('dimensions must agree')
def __matmul__(self, op):
......@@ -473,6 +474,17 @@ class LazyLinearOp(LinearOperator):
if isinstance(op, np.ndarray) or issparse(op):
if op.ndim == 1 and self._root_obj is not None:
res = self.lambdas['@'](op.reshape(op.size, 1)).ravel()
elif op.ndim > 2:
from itertools import product
# op.ndim > 2
res = np.empty((*op.shape[:-2], self.shape[0], op.shape[-1]))
idl = [ list(range(op.shape[i])) for i in range(op.ndim-2) ]
for t in product(*idl):
tr = (*t, slice(0, res.shape[-2]), slice(0, res.shape[-1]))
to = (*t, slice(0, op.shape[-2]), slice(0, op.shape[-1]))
R = self.lambdas['@'](op.__getitem__(to))
res.__setitem__(tr, R)
# TODO: try to parallelize
else:
res = self.lambdas['@'](op)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment