Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 71b0f2ab authored by hhakim's avatar hhakim
Browse files

Optimize pyfaust.__add__ (all the arguments are added with the Faust in almost...

Optimize pyfaust.__add__ (all the arguments are added with the Faust in almost one C++ call instead of repeating the concatenation/addition for each one).

Issue #35.
parent 752cb38d
No related branches found
No related tags found
No related merge requests found
......@@ -523,7 +523,7 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
print(F.__repr__())
#F.m_faust.display()
def __add__(F, *args):
def __add__(F, *args, **kwargs):
"""
Sums F to one or a sequence of variables. Faust objects, arrays or scalars.
......@@ -547,22 +547,93 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
See also Faust.__sub__
"""
if 'opt' not in kwargs or kwargs['opt']:
return F._add_opt(*args)
else:
return F._add_nonopt(*args)
def _add_opt(F, *args):
array_types = (np.ndarray,
scipy.sparse.csr_matrix,
scipy.sparse.csc_matrix)
dev = F.device
brd_err = ("operands could not be broadcast"
" together with shapes ")
# handle possible broadcasting of F
# if dimensions are inconsistent it'll fail later
if(F.shape[1] == 1):
max_ncols = np.max([a.shape[1] for a in args if isinstance(a,
(Faust,
*array_types))])
F = F@Faust(np.ones(1, max_ncols))
if(F.shape[0] == 1):
max_nrows = np.max([a.shape[0] for a in args if isinstance(a,
(Faust,
*array_types))])
F = Faust(np.ones(max_nrows, 1))@F
def scalar2Faust(G):
if isinstance(G, int):
G = float(G)
elif not isinstance(G, (np.float, np.complex)):
raise TypeError("scalar must be int, float or complex")
G, Gdtype = (float(G), np.float) if isinstance(G, np.float) else (G,
np.complex)
return Faust([np.ones((F.shape[0], 1))*G,
np.ones((1, F.shape[1])).astype(Gdtype)])
def broadcast_to_F(G):
if G.shape[0] == 1:
if G.shape[1] != F.shape[1]:
raise ve
G = Faust(np.ones((F.shape[0], 1))) @ G
elif G.shape[1] == 1:
if G.shape[0] != F.shape[0]:
raise ve
G = G @ Faust(np.ones((1, F.shape[1])))
return G
# prepare the list of Fausts
largs = []
for i in range(0,len(args)):
G = args[i]
if isinstance(G,Faust):
ve = ValueError(brd_err, F.shape,
G.shape, " argument i=", i)
G = broadcast_to_F(G)
if F.shape != G.shape:
raise Exception('Dimensions must agree, argument i=', i)
elif isinstance(G,
array_types):
if(G.size == 1):
G = scalar2Faust(G.reshape(1,)[0])
if G.ndim == 1:
G = Faust([np.ones((F.shape[0], 1)), G.reshape(1, G.size)])
else:
G = broadcast_to_F(Faust(G))
elif isinstance(G,(int, float, np.complex)):
G = scalar2Faust(G)
largs.append(G)
id = np.eye(F.shape[1])
id_vstack = np.vstack([id for i in range(0,
len(largs)+1)])
C = F.concatenate(*largs, axis=1)
F = C@Faust(id_vstack)
return F
def _add_nonopt(F, *args):
dev = F.device
for i in range(0,len(args)):
G = args[i]
if isinstance(G,Faust):
ve = ValueError("operands could not be broadcast"
" together with shapes ", F.shape,
G.shape)
if G.shape[0] == 1:
if G.shape[1] != F.shape[1]:
raise ValueError("operands could not be broadcast"
" together with shapes ", F.shape,
G.shape)
raise ve
G = Faust(np.ones((F.shape[0], 1))) @ G
elif G.shape[1] == 1:
if G.shape[0] != F.shape[0]:
raise ValueError("operands could not be broadcast"
" together with shapes ", F.shape,
G.shape)
raise ve
G = G @ Faust(np.ones((1, F.shape[1])))
if F.shape != G.shape:
raise Exception('Dimensions must agree')
......@@ -570,16 +641,17 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
# Id = np.eye(int(C.shape[1]/2))
Fid = eye(int(C.shape[1]/2), dev=dev)
#F = C*Faust(np.concatenate((Id,Id)),axis=0)
F = C*Fid.concatenate(Fid,axis=0)
elif (isinstance(G,np.ndarray) or
isinstance(G,scipy.sparse.csr_matrix) or
isinstance(G,scipy.sparse.csc_matrix)):
F = C@Fid.concatenate(Fid,axis=0)
elif isinstance(G,
(np.ndarray,
scipy.sparse.csr_matrix,
scipy.sparse.csc_matrix)):
if(G.size == 1):
if G.dtype == np.complex:
return F*(np.complex(G.squeeze()))
F = F+(np.complex(G.squeeze()))
else:
return F*(float(G.squeeze()))
if G.ndim == 1:
F = F+(float(G.squeeze()))
elif G.ndim == 1:
G = Faust([np.ones((F.shape[0], 1)), G.reshape(1, G.size)])
if(F.shape[1] == 1):
F = F@Faust(np.ones((1,F.shape[0])))
......@@ -1032,26 +1104,28 @@ class Faust(numpy.lib.mixins.NDArrayOperatorsMixin):
if(axis not in [0,1]):
raise ValueError("Axis must be 0 or 1.")
largs = []
for i,G in enumerate(args):
if(isinstance(G, (np.ndarray, csr_matrix, csc_matrix))):
args[i] = Faust([G])
G = Faust(G)
if(not isinstance(G, Faust)): raise ValueError("You can't concatenate a "
"Faust with something "
"that is not a Faust, "
"a numpy array or scipy "
"sparse matrix.")
largs.append(G)
if(axis == 0 and F.shape[1] != G.shape[1] or axis == 1 and F.shape[0]
!= G.shape[0]): raise ValueError("The dimensions of "
"the two Fausts must "
"agree.")
if all([isFaust(G) for G in args]) and not "iterative" in kwargs.keys() or kwargs['iterative']:
if all([isFaust(G) for G in largs]) and not "iterative" in kwargs.keys() or kwargs['iterative']:
# use iterative meth.
if axis == 0:
C = Faust(core_obj=F.m_faust.vertcatn(*[G.m_faust for G in args]))
C = Faust(core_obj=F.m_faust.vertcatn(*[G.m_faust for G in largs]))
else: # axis == 1
C = Faust(core_obj=F.m_faust.horzcatn(*[G.m_faust for G in args]))
C = Faust(core_obj=F.m_faust.horzcatn(*[G.m_faust for G in largs]))
return C
# use recursive meth.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment