Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 8be6abe9 authored by hhakim's avatar hhakim
Browse files

Apply fixes in pyfaust.poly.

- Fix pyfaust.poly.Chebyshev(Faust(L), K, impl=py) wrong dimension of basis
- Fix pyfaust.poly._zeros dtype when M a Faust.
- (Chebyshev) Raise an error when L is a Faust but impl='native' (it must be 'py' as stated in the doc).
- Minor changes in doc and formatting.
parent 4110f65e
Branches
Tags
No related merge requests found
......@@ -24,8 +24,8 @@ def Chebyshev(L, K, dev='cpu', T0=None, impl='native'):
Builds the Faust of the Chebyshev polynomial basis defined on the sparse matrix L.
Args:
L: (scipy.sparse.csr_matrix or L)
the sparse square matrix or square Faust (in this case impl must be
L: (scipy.sparse.csr_matrix or Faust)
the sparse square matrix or square Faust (in the last case impl must be
'py')
K: (int)
the degree of the last polynomial, i.e. the K+1 first polynomials are built.
......@@ -48,24 +48,27 @@ def Chebyshev(L, K, dev='cpu', T0=None, impl='native'):
if L.shape[0] != L.shape[1]:
raise ValueError('L must be a square matrix.')
if impl == "py":
twoL = 2*L
twoL = 2 * L
d = L.shape[0]
# Id = sp.eye(d, format="csr")
Id = _eyes_like(L, d)
if isinstance(T0, type(None)):
if T0 is None:
T0 = Id
T1 = _vstack((Id, L))
rR = _hstack((-1*Id, twoL))
rR = _hstack((-1 * Id, twoL))
if isFaust(L):
if isFaust(T0):
T0 = T0.factors(0)
G = FaustPoly(T0, T1=T1, rR=rR, L=L, dev=dev, impl='py')
for i in range(0, K):
for i in range(0, K-1):
next(G)
return next(G)
else:
return _chebyshev(L, K, T0, T1, rR, dev)
elif impl == 'native':
if isFaust(L):
raise TypeError("L cannot be a Faust if impl is 'native', try"
" impl='py'")
if L.dtype == 'complex':
F = FaustPoly(core_obj=_FaustCorePy.FaustAlgoGenCplxDbl.polyBasis(L, K, T0,
dev.startswith('gpu')),
......@@ -156,7 +159,7 @@ def basis(L, K, basis_name, dev='cpu', T0=None, **kwargs):
"""
# impl (optional): 'native' (by default) for the C++ impl., "py" for the Python impl.
# L can aslo be a Faust if impl is "py".
# L can also be a Faust if impl is "py".
impl = 'native'
if 'impl' in kwargs:
if kwargs['impl'] in ['py', 'native']:
......@@ -412,7 +415,7 @@ def _zeros_like(M, shape=None):
if isinstance(shape, type(None)):
shape = M.shape
if isFaust(M):
zero = csr_matrix(([0], ([0], [0])), shape=shape)
zero = csr_matrix(([0], ([0], [0])), shape=shape, dtype=M.dtype)
return Faust(zero)
elif isinstance(M, csr_matrix):
zero = csr_matrix(shape, dtype=M.dtype)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment