Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 4ac0c4f5 authored by hhakim's avatar hhakim
Browse files

Add pyfaust.lazylinop.block_diag with documentation.

[skip ci]
parent 5c585c63
Branches
Tags
No related merge requests found
......@@ -1252,7 +1252,7 @@ def diag(v, k=0):
def sum(*lops, mt=False, af=False):
"""
Sums (lazily) all LazyLinearOp-s in lops.
Sums (lazily) all linear operators in lops.
Args:
lops: the objects to add up as a list of LazyLinearOp-s or other compatible linear operator.
......@@ -1327,3 +1327,81 @@ def sum(*lops, mt=False, af=False):
return S
return LazyLinearOperator(lops[0].shape, matmat=lambda x: matmat(x, lAx),
rmatmat=lambda x: matmat(x, lAHx))
def block_diag(*lops, mt=False):
"""
Returns the block diagonal LazyLinearOp formed of operators in lops.
Args:
lops: the objects defining the diagonal blocks as a list of LazyLinearOp-s or other compatible linear operator.
mt: True to active the multithread experimental mode (not advisable, so
far it's not faster than sequential execution).
Returns:
The LazyLinearOp for the sum of lops.
Example:
>>> import numpy as np
>>> from pyfaust.lazylinop import block_diag, aslazylinearoperator
>>> from scipy.sparse import diags
>>> nt = 10
>>> d = 64
>>> v = np.random.rand(d)
>>> terms = [np.random.rand(64, 64) for _ in range(10)]
>>> ls = block_diag(*terms) # ls is the block diagonal LazyLinearOperator
<b>See also:</b> scipy.linalg.block_diag
"""
lAx = lambda A, x: A @ x
lAHx = lambda A, x: A.T.conj() @ x
offsets = [0]
for i in range(len(lops)):
offsets += [offsets[i] + lops[i].shape[1]]
def matmat(x, lmul):
from threading import Thread
from multiprocessing import cpu_count
Ps = [None for _ in range(len(lops))]
n = len(lops)
class Mul(Thread):
def __init__(self, As, x, out, i):
self.As = As
self.x = x
self.out = out
self.i = i
super(Mul, self).__init__(target=self.run)
def run(self):
for i, A in enumerate(self.As):
ipi = self.i + i
self.out[ipi] = lmul(A,
self.x[offsets[ipi]:offsets[ipi+1]])
if mt:
ths = []
nths = min(cpu_count(), n)
share = [n // nths for _ in range(nths)]
rem = n - share[0] * nths
if rem > 0:
while rem > 0:
share[rem-1] += 1
rem -= 1
for i in range(1, len(share)):
share[i] += share[i-1]
share = [0] + share
for i in range(nths):
start = share[i]
end = share[i+1]
ths += [Mul(lops[start:end], x, Ps, start)]
ths[-1].start()
for i in range(nths):
ths[i].join()
else:
for i, A in enumerate(lops):
Ps[i] = lmul(A, x[offsets[i]:offsets[i+1]])
S = Ps[0]
for i in range(1, n):
S = np.vstack((S, Ps[i]))
return S
return LazyLinearOperator((np.sum([A.shape[0] for A in lops]), np.sum([A.shape[1] for A in lops])), matmat=lambda x: matmat(x, lAx),
rmatmat=lambda x: matmat(x, lAHx))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment