Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 6ab93b63 authored by hhakim's avatar hhakim
Browse files

Rewrite pyfaust Faust.real/imag functions to obtain a real dtype and set these...

Rewrite pyfaust Faust.real/imag functions to obtain a real dtype and set these functions as properties (update DST and DST consequently).
parent 584ad61d
Branches
Tags
No related merge requests found
......@@ -2255,11 +2255,12 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
" available in FAµST")
else:
# dtype is double
return F.real()
return F.real
else:
return Faust([F.factors(i).astype(dtype) for i in
range(F.numfactors())], dev=F.device)
@property
def real(F):
"""
Returns the real part of F as a Faust.
......@@ -2267,16 +2268,28 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
if F.dtype != np.complex:
return F
else:
return 1/2 * (F + F.conj())
# return 1/2 * (F + F.conj())
return (_cplx2real_op(F)[:F.shape[0],
:2*F.shape[1]]@Faust(svstack((seye(F.shape[1]),
csr_matrix((F.shape[1],
F.shape[1])))))).pruneout()
#return 1/2 * (F + F.conj())
@property
def imag(F):
"""
Returns the imaginary part of F as a Faust.
"""
if F.dtype != np.complex:
return F
# return Faust(csr_matrix(F.shape)) # TODO: debug pyx code
return Faust(csr_matrix((np.array([0.]).astype(F.dtype),
([0],[0])), (F.shape)))
else:
return 1/2j * (F + F.conj())
# return 1/2j * (F + F.conj())
return (_cplx2real_op(F)[F.shape[0]:2*F.shape[0],
:2*F.shape[1]]@Faust(svstack((seye(F.shape[1]),
csr_matrix((F.shape[1],
F.shape[1])))))).pruneout()
def asarray(F, *args, **kwargs):
return F
......@@ -3332,7 +3345,7 @@ def dct(n, dev='cpu'):
mid_F = mid_factors
else:
mid_F = Faust(mid_factors)
DCT = (Faust(f0) @ mid_F @ Faust(f_end)).real()
DCT = (Faust(f0) @ mid_F @ Faust(f_end)).real
return DCT
# experimental block start
......@@ -3382,7 +3395,7 @@ def dst3(n, dev='cpu'):
F_odd = SD1 @ Faust(D2) @ DFT
F = pyfaust.hstack((F_even, F_odd))
F = F @ Faust(P_)
return F.real()
return F.real
# experimental block end
def dst(n, dev='cpu'):
......@@ -3482,7 +3495,7 @@ def dst(n, dev='cpu'):
F_odd = Faust(D1, dev=dev) @ Faust(D2) @ MDFT
F = pyfaust.hstack((F_even, F_odd))
F = F @ Faust(P_, dev=dev)
return F.real()
return F.real
def circ(c, **kwargs):
"""Returns a circulant Faust C defined by the vector c (which is the first column of C.toarray()).
......@@ -4010,6 +4023,24 @@ def check_dev(dev):
elif dev != 'cpu':
raise ValueError("dev must be 'cpu' or 'gpu[:id]'")
def _cplx2real_op(op):
if pyfaust.isFaust(op):
return Faust([_cplx2real_op(op.factors(i)) for i in range(op.numfactors())])
else:
rop = np.real(op)
iop = np.imag(op)
if isinstance(op, (csr_matrix, csc_matrix, coo_matrix, bsr_matrix)):
vertcat = svstack
horzcat = shstack
elif isinstance(op, np.ndarray):
vertcat = vstack
horzcat = hstack
else:
raise TypeError('op must be a scipy sparse matrix or a np.ndarray')
real_part = horzcat((rop, - iop))
imag_part = horzcat((iop, rop))
return vertcat((real_part, imag_part))
# experimental block start
# @PYTORCH_EXP_CODE@
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment