Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 00c441eb authored by hhakim's avatar hhakim
Browse files

Add __array_ufunc__ to handle ufuncs functions.

__array_function__ and @implements decorator are not necessary but will probably be useful later with the numpy 1.16 protocol (optional in version 1.16 but enabled by default in 1.17).

https://numpy.org/doc/stable/reference/arrays.classes.html?highlight=array_function#numpy.class.__array_function__
parent 7d696ee4
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,11 @@ import pyfaust
import pyfaust.factparams
import warnings
import decimal
import numpy.lib.mixins
class Faust:
HANDLED_FUNCTIONS = {}
class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
"""<b>FAuST Python wrapper main class</b> for using multi-layer sparse transforms.
This class provides a numpy-like interface for operations
......@@ -166,6 +169,37 @@ class Faust:
else:
raise Exception("Cannot create an empty Faust.")
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
if method == '__call__':
if ufunc.__eq__("<ufunc 'matmul'>") and len(inputs) >= 2 and \
isFaust(inputs[1]):
return inputs[1].__rmatmul__(inputs[0])
if ufunc.__eq__("<ufunc 'multiply'>") and len(inputs) >= 2 and \
isFaust(inputs[1]):
return inputs[1].__rmul__(inputs[0])
N = None
fausts = []
elif method == 'reduce':
# # not necessary numpy calls Faust.sum
# if ufunc == "<ufunc 'add'>":
# if len(inputs) == 1 and pyfaust.isFaust(inputs[0]):
# #return inputs[0].sum(*inputs[1:], **kwargs)
# else:
return NotImplemented
def __array__(self, *args, **kwargs):
return self
def __array_function__(self, func, types, args, kwargs):
print("__array_function__")
if func not in HANDLED_FUNCTIONS:
return NotImplemented
# Note: this allows subclasses that don't override
# __array_function__ to handle MyArray objects
if not all(issubclass(t, Faust) for t in types):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)
@property
def nbytes(F):
"""
......@@ -447,6 +481,11 @@ class Faust:
_str = str(F.m_faust.to_string())
return _str
def __str__(F):
"""
"""
return F.__repr__()
def display(F):
"""
Displays information about F.
......@@ -757,7 +796,7 @@ class Faust:
else:
return F.m_faust.multiply(A)
def dot(F, A):
def dot(F, A, *args, **kwargs):
"""
Performs equivalent operation of numpy.dot() between the Faust F and A.
......@@ -868,9 +907,6 @@ class Faust:
except:
raise TypeError("invalid type operand for Faust.__rmul__.")
__array_ufunc__ = None # mandatory to override rmatmul
# it means Faust doesn't support ufuncs
def __rmatmul__(F, lhs_op):
"""
Returns lhs_op.__matmul__(F).
......@@ -1694,7 +1730,7 @@ class Faust:
else:
raise ValueError("complex -> float conversion not yet supported.")
def asarray(F):
def asarray(F, *args, **kwargs):
print('Faust.asarray')
#TODO: full list of numpy args or **kw_unknown
return F
......@@ -1928,6 +1964,10 @@ class Faust:
else:
return Faust([F.factors(i) for i in range(F.numfactors())], dev=dev)
def sum(F, axis=None, **kwargs):
return F@Faust(np.ones((F.shape[1], 1)))
def average(F, axis=None, weights=None, returned=False):
"""
......@@ -2001,6 +2041,14 @@ class Faust:
pyfaust.Faust.__div__ = pyfaust.Faust.__truediv__
def implements(numpy_function):
"""Register an __array_function__ implementation for MyArray
objects."""
def decorator(func):
HANDLED_FUNCTIONS[numpy_function] = func
return func
return decorator
def version():
"""Returns the FAuST package version.
"""
......@@ -2038,7 +2086,8 @@ def norm(F, ord='fro', **kwargs):
return np.linalg.norm(F, ord, axis=axis,
keepdims=keepdims)
def dot(A, B):
@implements(np.dot)
def dot(A, B, **kwargs):
"""
Returns Faust.dot(A,B) if A or B is a Faust object, returns numpy.dot(A,B) ortherwise.
......@@ -2064,7 +2113,9 @@ def pinv(F):
else:
return np.linalg.linalg.pinv(F)
def concatenate(_tuple, axis=0):
@implements(np.concatenate)
def concatenate(F, axis=0, **kwargs):
"""
A package function alias for the member function Faust.concatenate.
......@@ -2385,6 +2436,7 @@ def enable_gpu_mod(libpaths=None, backend='cuda', silent=False, fatal=False):
"""
_FaustCorePy.FaustCore.enable_gpu_mod(libpaths, backend, silent, fatal)
# experimental block start
import torch
......@@ -2597,3 +2649,4 @@ class FaustMulMode:
##
## Multiplying from the left to the right or in the way around in order to minimize the cost.
GPU_MOD=10
# -*- coding: utf-8 -*-
# @PYFAUST_LICENSE_HEADER@
from pyfaust import *
import numpy as np
import _FaustCorePy
import sys
if sys.version_info > (3,0):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment