Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 4a4e7c62 authored by hhakim's avatar hhakim
Browse files

Modify pyfaust.poly to be able to handle a L indifferently whether it's a csr_matrix or a Faust.

parent bd9d5dde
Branches
Tags 3.0.15
No related merge requests found
Pipeline #833990 skipped
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
import scipy.sparse as sp import scipy.sparse as sp
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from pyfaust import Faust, isFaust from pyfaust import (Faust, isFaust, eye as feye, vstack as fvstack, hstack as
fhstack)
from scipy.sparse.linalg import eigsh from scipy.sparse.linalg import eigsh
...@@ -26,60 +27,28 @@ def Chebyshev(L, K, ret_gen=False, dev='cpu', T0=None): ...@@ -26,60 +27,28 @@ def Chebyshev(L, K, ret_gen=False, dev='cpu', T0=None):
Returns: Returns:
The Faust of the K+1 Chebyshev polynomials. The Faust of the K+1 Chebyshev polynomials.
""" """
if not isinstance(L, csr_matrix): if not isinstance(L, csr_matrix) and not isFaust(L):
L = csr_matrix(L) L = csr_matrix(L)
twoL = 2*L twoL = 2*L
d = L.shape[0] d = L.shape[0]
Id = sp.eye(d, format="csr") # Id = sp.eye(d, format="csr")
Id = _eyes_like(L, d)
if isinstance(T0, type(None)): if isinstance(T0, type(None)):
T0 = Id T0 = Id
T1 = sp.vstack((Id, L)) T1 = _vstack((Id, L))
rR = sp.hstack((-Id, twoL), format="csr") rR = _hstack((-1*Id, twoL))
if ret_gen: if ret_gen or isFaust(L):
g = _chebyshev_gen(L, T0, T1, rR, dev) g = _chebyshev_gen(L, T0, T1, rR, dev)
for i in range(0, K): for i in range(0, K):
next(g) next(g)
return next(g), g if ret_gen:
return next(g), g
else:
return next(g)
else: else:
return _chebyshev(L, K, T0, T1, rR, dev) return _chebyshev(L, K, T0, T1, rR, dev)
def _chebyshev(L, K, T0, T1, rR, dev='cpu'):
d = L.shape[0]
factors = [T0]
if(K > 0):
factors.insert(0, T1)
for i in range(2, K + 1):
Ti = _chebyshev_Ti_matrix(rR, L, i)
factors.insert(0, Ti)
T = Faust(factors, dev=dev)
return T # K-th poly is T[K*L.shape[0]:,:]
def _chebyshev_gen(L, T0, T1, rR, dev='cpu'):
T = Faust(T0)
yield T
T = Faust(T1) @ T
yield T
i = 2
while True:
Ti = _chebyshev_Ti_matrix(rR, L, i)
T = Faust(Ti) @ T
yield T
i += 1
def _chebyshev_Ti_matrix(rR, L, i):
d = L.shape[0]
if i <= 2:
R = rR
else:
zero = csr_matrix((d, (i-2)*d), dtype=float)
R = sp.hstack((zero, rR), format="csr")
Ti = sp.vstack((sp.eye(d*i, format="csr"), R),
format="csr")
return Ti
def basis(L, K, basis_name, ret_gen=False, dev='cpu', T0=None): def basis(L, K, basis_name, ret_gen=False, dev='cpu', T0=None):
""" """
Builds the Faust of the polynomial basis defined on the symmetric matrix L. Builds the Faust of the polynomial basis defined on the symmetric matrix L.
...@@ -98,7 +67,8 @@ def basis(L, K, basis_name, ret_gen=False, dev='cpu', T0=None): ...@@ -98,7 +67,8 @@ def basis(L, K, basis_name, ret_gen=False, dev='cpu', T0=None):
The Faust of the K+1 Chebyshev polynomials. The Faust of the K+1 Chebyshev polynomials.
""" """
if basis_name.lower() == 'chebyshev': if basis_name.lower() == 'chebyshev':
return Chebyshev(L, K, ret_gen=ret_gen, dev=dev, T0=None) return Chebyshev(L, K, ret_gen=ret_gen, dev=dev, T0=T0)
def poly(coeffs, L=None, basis=Chebyshev, dev='cpu'): def poly(coeffs, L=None, basis=Chebyshev, dev='cpu'):
""" """
...@@ -130,4 +100,118 @@ def poly(coeffs, L=None, basis=Chebyshev, dev='cpu'): ...@@ -130,4 +100,118 @@ def poly(coeffs, L=None, basis=Chebyshev, dev='cpu'):
format="csr") format="csr")
Fc = Faust(scoeffs, dev=dev) @ F Fc = Faust(scoeffs, dev=dev) @ F
return Fc return Fc
def _chebyshev(L, K, T0, T1, rR, dev='cpu'):
d = L.shape[0]
factors = [T0]
if(K > 0):
factors.insert(0, T1)
for i in range(2, K + 1):
Ti = _chebyshev_Ti_matrix(rR, L, i)
factors.insert(0, Ti)
T = Faust(factors, dev=dev)
return T # K-th poly is T[K*L.shape[0]:,:]
def _chebyshev_gen(L, T0, T1, rR, dev='cpu'):
if isFaust(T0):
T = T0
else:
T = Faust(T0)
yield T
if isFaust(T1):
T = T1 @ T
else:
T = Faust(T1) @ T
yield T
i = 2
while True:
Ti = _chebyshev_Ti_matrix(rR, L, i)
if isFaust(Ti):
T = Ti @ T
else:
T = Faust(Ti) @ T
yield T
i += 1
def _chebyshev_Ti_matrix(rR, L, i):
d = L.shape[0]
if i <= 2:
R = rR
else:
#zero = csr_matrix((d, (i-2)*d), dtype=float)
zero = _zeros_like(L, shape=(d, (i-2)*d))
R = _hstack((zero, rR))
di = d*i
Ti = _vstack((_eyes_like(L, shape=di), R))
return Ti
def _zeros_like(M, shape=None):
"""
Returns a zero of the same type of M: csr_matrix, pyfaust.Faust.
"""
if isinstance(shape, type(None)):
shape = M.shape
if isFaust(M):
zero = csr_matrix(([0], ([0], [0])), shape=shape)
return Faust(zero)
elif isinstance(M, csr_matrix):
zero = csr_matrix(shape, dtype=M.dtype)
return zero
else:
raise TypeError('M must be a Faust or a scipy.sparse.csr_matrix.')
def _eyes_like(M, shape=None):
"""
Returns an identity of the same type of M: csr_matrix, pyfaust.Faust.
"""
if isinstance(shape, type(None)):
shape = M.shape[1]
if isFaust(M):
return feye(shape)
elif isinstance(M, csr_matrix):
return sp.eye(shape, format='csr')
else:
raise TypeError('M must be a Faust or a scipy.sparse.csr_matrix.')
def _vstack(arrays):
_arrays = _build_consistent_tuple(arrays)
if isFaust(arrays[0]):
# all arrays are of type Faust
return fvstack(arrays)
else:
# all arrays are of type csr_matrix
return sp.vstack(arrays, format='csr')
def _hstack(arrays):
_arrays = _build_consistent_tuple(arrays)
if isFaust(arrays[0]):
# all arrays are of type Faust
return fhstack(arrays)
else:
# all arrays are of type csr_matrix
return sp.hstack(arrays, format='csr')
def _build_consistent_tuple(arrays):
contains_a_Faust = False
for a in arrays:
if isFaust(a):
contains_a_Faust = True
break
if contains_a_Faust:
_arrays = []
for a in arrays:
if not isFaust(a):
a = Faust(a)
_arrays.append(a)
return tuple(_arrays)
else:
return arrays
# experimental block end # experimental block end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment