Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 042dc39e authored by hhakim's avatar hhakim
Browse files

Parallelize pyfaust.lazylinop.LazyLinearOpKron.__matmul__ when multiplying a X...

Parallelize pyfaust.lazylinop.LazyLinearOpKron.__matmul__ when multiplying a X with multiple columns and add an option (disabled by default) to force the evaluation (just for comparison).
parent 925ea7a8
Branches
Tags 3.35.5
No related merge requests found
Pipeline #834131 skipped
......@@ -820,22 +820,56 @@ class LazyLinearOpKron(LazyLinearOp):
<b>See also:</b> pyfaust.lazylinop.kron.
"""
from threading import Thread
from multiprocessing import cpu_count
from os import environ
from pyfaust import isFaust
self._sanitize_matmul(op)
force_eval = False # find no case where it is worth to be True
if 'KRON_MATMUL_FORCE_EVAL' in environ:
force_eval = environ['KRON_MATMUL_FORCE_EVAL'] == '1'
if hasattr(op, 'reshape') and hasattr(op, '__matmul__') and hasattr(op,
'__getitem__'):
if op.ndim == 1:
op = op.reshape((op.size, 1))
one_dim = True
if force_eval:
res = self.eval() @ op
else:
one_dim = False
res = np.empty((self.shape[0], op.shape[1]))
for j in range(op.shape[1]):
op_mat = op[:, j].reshape((self.A.shape[1], self.B.shape[1]))
res[:, j] = (LazyLinearOp._eval_if_lazy(self.A) @ op_mat @
LazyLinearOp._eval_if_lazy(self.B).T).reshape(self.shape[0])
if one_dim:
res = res.ravel()
if isFaust(self.B) or isFaust(self.B):
parallel = False # e.g. for A, B Fausts in R^100x100 and op 128 columns
# it was found that the sequential computation was faster
else:
parallel = True
if 'KRON_PARALLEL' in environ:
parallel = environ['KRON_PARALLEL'] == '1'
nthreads = cpu_count() // 2
if op.ndim == 1:
op = op.reshape((op.size, 1))
one_dim = True
else:
one_dim = False
res = np.empty((self.shape[0], op.shape[1]))
def out_col(j, ncols):
for j in range(j, min(j + ncols, op.shape[1])):
op_mat = op[:, j].reshape((self.A.shape[1], self.B.shape[1]))
res[:, j] = (LazyLinearOp._eval_if_lazy(self.A) @ op_mat @
LazyLinearOp._eval_if_lazy(self.B).T).reshape(self.shape[0])
ncols = op.shape[1]
if parallel:
t = []
cols_per_thread = ncols // nthreads
if cols_per_thread * nthreads < ncols:
cols_per_thread += 1
while len(t) < nthreads:
t.append(Thread(target=out_col, args=(cols_per_thread *
len(t),
cols_per_thread)))
t[-1].start()
for j in range(nthreads):
t[j].join()
else:
out_col(0, ncols)
if one_dim:
res = res.ravel()
else:
res = LazyLinearOp(init_lambda=lambda:
self.eval() @ LazyLinearOp._eval_if_lazy(op),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment