Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 71857df5 authored by hhakim's avatar hhakim
Browse files

Update matfaust.lazylinop multiplication by scalar to be less costly.

parent 3cde8cdf
Branches
No related tags found
No related merge requests found
...@@ -165,7 +165,7 @@ classdef LazyLinearOp < handle % needed to use references on objects ...@@ -165,7 +165,7 @@ classdef LazyLinearOp < handle % needed to use references on objects
%> @param op: an object compatible with self for this binary operation. %> @param op: an object compatible with self for this binary operation.
%============================================================= %=============================================================
function LM = mtimes(L, A) function LM = mtimes(L, A)
import matfaust.lazylinop.LazyLinearOp import matfaust.lazylinop.*
if LazyLinearOp.isLazyLinearOp(A) if LazyLinearOp.isLazyLinearOp(A)
if isscalar(L) if isscalar(L)
LM = mtimes(A, L); LM = mtimes(A, L);
...@@ -176,24 +176,19 @@ classdef LazyLinearOp < handle % needed to use references on objects ...@@ -176,24 +176,19 @@ classdef LazyLinearOp < handle % needed to use references on objects
end end
end end
check_meth(L, 'mtimes'); check_meth(L, 'mtimes');
op_is_scalar = all(size(A) == [1, 1]); A_is_scalar = all(size(A) == [1, 1]);
if ~ op_is_scalar && ~ all(size(L, 2) == size(A, 1)) if ~ A_is_scalar && ~ all(size(L, 2) == size(A, 1))
error('Dimensions must agree') error('Dimensions must agree')
end end
if op_is_scalar
new_size = size(L);
else
new_size = [size(L, 1), size(A, 2)];
end
function l = mul_index_lambda(L, A, S) function LMI = mul_index_lambda(L, A, S)
% L and A must be LazyLinearOp % L and A must be LazyLinearOp
import matfaust.lazylinop.LazyLinearOp import matfaust.lazylinop.LazyLinearOp
Sr.type = '()'; Sr.type = '()';
Sr.subs = {S.subs{1}, ':'}; Sr.subs = {S.subs{1}, ':'};
Sc.type = '()'; Sc.type = '()';
Sc.subs = {':', S.subs{2}}; Sc.subs = {':', S.subs{2}};
L.lambdas{L.I}(Sr) * A.lambdas{L.I}(Sc); LMI = L.lambdas{L.I}(Sr) * A.lambdas{L.I}(Sc);
end end
if ~ LazyLinearOp.isLazyLinearOp(A) && ismatrix(A) && isnumeric(A) && any(size(A) ~= [1, 1]) if ~ LazyLinearOp.isLazyLinearOp(A) && ismatrix(A) && isnumeric(A) && any(size(A) ~= [1, 1])
...@@ -201,17 +196,19 @@ classdef LazyLinearOp < handle % needed to use references on objects ...@@ -201,17 +196,19 @@ classdef LazyLinearOp < handle % needed to use references on objects
LM = L.lambdas{L.MUL}(A); LM = L.lambdas{L.MUL}(A);
else else
if isscalar(A) if isscalar(A)
LM_size = size(L); matmat = @(M) M * A;
LM = L * LazyLinearOperator([L.size(2), L.size(2)], 'matmat', matmat, 'rmatmat', matmat);
return;
else else
if ~ LazyLinearOp.isLazyLinearOp(A) if ~ LazyLinearOp.isLazyLinearOp(A)
A = LazyLinearOp.create_from_op(A); A = LazyLinearOp.create_from_op(A);
end end
LM_size = [size(L, 1), size(A, 2)] LM_size = [size(L, 1), size(A, 2)];
end end
lambdas = {@(o) L * (A * o), ... %MUL lambdas = {@(o) L * (A * o), ... %MUL
@() A.' * L.', ... % T @() A.' * L.', ... % T
@() A' * L', ... % H @() A' * L', ... % H
@(S) mul_index_lambda(L, A, S)% I @(S) mul_index_lambda(L, A, S)% I
}; };
LM = LazyLinearOp(lambdas, LM_size); LM = LazyLinearOp(lambdas, LM_size);
......
...@@ -590,11 +590,11 @@ class LazyLinearOp(LinearOperator): ...@@ -590,11 +590,11 @@ class LazyLinearOp(LinearOperator):
<b>See also:</b> pyfaust.lazylinop.LazyLinearOp.__matmul__) <b>See also:</b> pyfaust.lazylinop.LazyLinearOp.__matmul__)
""" """
self._checkattr('__mul__') self._checkattr('__mul__')
from scipy.sparse import eye
if np.isscalar(other): if np.isscalar(other):
S = eye(self.shape[1], format='csr') * other Dshape = (self.shape[1], self.shape[1])
lop = LazyLinearOp.create_from_op(S) matmat = lambda M: M * other
new_op = self @ lop D = LazyLinearOperator(Dshape, matmat=matmat, rmatmat=matmat)
new_op = self @ D
else: else:
new_op = self @ other new_op = self @ other
return new_op return new_op
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment