Mentions légales du service

Skip to content
Snippets Groups Projects
Commit b106c986 authored by hhakim's avatar hhakim
Browse files

Optimize pyfaust/matfaust.poly.expm_multiply: avoid recomputing T*B for each tau.

parent 17061642
No related branches found
No related tags found
No related merge requests found
......@@ -89,6 +89,9 @@ function C = expm_multiply(A, B, t, varargin)
n = size(B, 2);
npts = numel(t);
Y = zeros(npts, m, n);
if (poly_meth == 2)
TB = T*B;
end
for i=1:npts
tau = t(i);
if(tau >= 0)
......@@ -102,7 +105,6 @@ function C = expm_multiply(A, B, t, varargin)
end
coeff(1) = coeff(1)*.5;
if(poly_meth == 2)
TB = T*B;
Y(i, :, :) = matfaust.poly.poly(coeff, TB, 'dev', dev);
elseif(poly_meth == 1)
Y(i, :, :) = matfaust.poly.poly(coeff, T, 'X', B, 'dev', dev);
......
......@@ -515,9 +515,10 @@ def expm_multiply(A, B, t, K=10, dev='cpu', **kwargs):
raise ValueError('A must be symmetric positive definite.')
poly_meth = 1
if 'poly_meth' in kwargs:
poly_meth = kwargs['poly_meth']
if poly_meth not in [1, 2]:
if kwargs['poly_meth'] not in [1, 2, '1', '2']:
raise ValueError('poly_meth must be 1 or 2')
poly_meth = int(kwargs['poly_meth'])
print("expm_multiply poly_meth:", poly_meth)
phi = eigsh(A, k=1, return_eigenvectors=False)[0] / 2
T = basis(A/phi-seye(*A.shape), K, 'chebyshev', dev=dev)
if isinstance(t, float):
......@@ -529,6 +530,8 @@ def expm_multiply(A, B, t, K=10, dev='cpu', **kwargs):
n = B.shape[1]
npts = len(t)
Y = empty((npts, m, n))
if poly_meth == 2:
TB = np.squeeze(T@B)
for i,tau in enumerate(t):
if tau >= 0:
raise ValueError('pyfaust.poly.expm_multiply handles only negative '
......@@ -541,7 +544,6 @@ def expm_multiply(A, B, t, K=10, dev='cpu', **kwargs):
coeff[j] = coeff[j+2] - (2 * j + 2) / (-tau * phi) * coeff[j+1]
coeff[0] /= 2
if poly_meth == 2:
TB = np.squeeze(T@B)
if n == 1:
Y[i,:,0] = np.squeeze(poly(coeff, TB, dev=dev))
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment