Mentions légales du service

Skip to content
Snippets Groups Projects
Commit ffc0d012 authored by hhakim's avatar hhakim
Browse files

Update pyfaust.Faust ctor to handle missing Faust creation by float array list...

Update pyfaust.Faust ctor to handle missing Faust creation by float array list on GPU and secure Faust dtype consistency.

- Raising an error if a factor is not in a supported dtype (float32, double or complex128).
- Faust.dtype relies on pyx dtype function that detects the dtype directly from the core object.
parent a2d748ee
Branches
Tags
No related merge requests found
......@@ -173,16 +173,21 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
dtype = f.dtype
elif dtype != f.dtype:
raise TypeError('All Faust factors must have the same dtype.')
if dtype not in ['double', 'float32', 'complex128']:
raise TypeError('Unmanaged factor dtype:'+str(dtype)+' (must be float32, double or complex128')
if(factors is not None and len(factors) > 0):
if(is_on_gpu):
if F._is_real:
F.m_faust = _FaustCorePy.FaustCoreGenDblGPU(factors, scale)
if dtype == 'double':
F.m_faust = _FaustCorePy.FaustCoreGenDblGPU(factors, scale)
else: # if dtype == 'float32':
F.m_faust = _FaustCorePy.FaustCoreGenFltGPU(factors, scale)
else:
F.m_faust = _FaustCorePy.FaustCoreGenCplxDblGPU(factors, scale)
elif F._is_real:
if dtype == 'double':
F.m_faust = _FaustCorePy.FaustCoreGenDblCPU(factors, scale)
elif dtype == 'float32':
else: # if dtype == 'float32':
F.m_faust = _FaustCorePy.FaustCoreGenFltCPU(factors, scale)
else:
F.m_faust = _FaustCorePy.FaustCoreGenCplxDblCPU(factors, scale)
......@@ -1846,10 +1851,7 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
"""
if(F.m_faust.isReal()):
return np.dtype(np.float64)
else:
return np.dtype(np.complex)
return F.m_faust.dtype()
def imshow(F, name='F'):
"""
......
def type2dtype(type):
return 'float32' if type == 'float' else type
return np.dtype('float32') if type == 'float' else np.dtype(type)
cdef class FaustCoreGen@TYPE_NAME@@PROC@:
......@@ -629,6 +629,9 @@ cdef class FaustCoreGen@TYPE_NAME@@PROC@:
@REAL_TYPE@].real((<@CORE_CLASS@?>self).@CORE_OBJ@)
return core
def dtype(self):
return type2dtype('@TYPE@')
def zpruneout(self, nnz_tres, npasses, only_forward):
core = @CORE_CLASS@(core=True)
core.@CORE_OBJ@ = self.@CORE_OBJ@.zpruneout(nnz_tres,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment