Mentions légales du service

Skip to content
Snippets Groups Projects
Commit b8fc4bf9 authored by hhakim's avatar hhakim
Browse files

Update FaustTorch class (pure py.) exprimental code with a new method...

Update FaustTorch class (pure py.) exprimental code with a new method totensor() and refactor __mul__ method in this goal.
parent 782365a2
No related branches found
No related tags found
No related merge requests found
...@@ -2138,7 +2138,7 @@ class FaustTorch: ...@@ -2138,7 +2138,7 @@ class FaustTorch:
else: else:
raise err raise err
def __mul__(self, op, optimize_chain=False): def __mul__(self, op=None, optimize_chain=False, allow_none_op=False):
""" """
Multiplies the chain of tensors into a single tensor (matrix product). Multiplies the chain of tensors into a single tensor (matrix product).
...@@ -2147,20 +2147,43 @@ class FaustTorch: ...@@ -2147,20 +2147,43 @@ class FaustTorch:
res = torch.from_numpy(op) res = torch.from_numpy(op)
if(self.device != None): if(self.device != None):
res.to(self.device) res.to(self.device)
factors = self.factors
if(optimize_chain):
return self.mul_opt(res)
elif(allow_none_op and op == None):
factors = self.factors[:]
if(factors[-1].is_sparse):
factors.append(torch.from_numpy(np.eye(self.factors[-1].size()[1])))
if(self.device != None):
factors[-1].to(self.device)
res = factors[-1].clone()
factors = factors[:-1]
if(optimize_chain):
tmp = self.factors
self.factors = factors
res = self.mul_opt(res)
self.factors = tmp
return res
else: else:
raise TypeError('op must be a np.ndarray') raise TypeError('op must be a np.ndarray')
if(optimize_chain):
return self.mul_opt(res)
# torch matmul # torch matmul
#res = self.factors[0] #res = self.factors[0]
#for f in self.factors[1:]: #for f in self.factors[1:]:
for f in reversed(self.factors[:]): for f in reversed(factors[:]):
if(f.is_sparse): if(f.is_sparse):
res = torch.sparse.mm(f, res) res = torch.sparse.mm(f, res)
else: else:
res = torch.matmul(f, res) res = torch.matmul(f, res)
return res return res
def totensor(self, optimize_chain=False):
"""
See Faust.toarray()
"""
return self.__mul__(allow_none_op=True, optimize_chain=optimize_chain)
def mul_opt(self, op): def mul_opt(self, op):
""" """
Computes the product self*op optimizing the order of matrix chain products. Computes the product self*op optimizing the order of matrix chain products.
...@@ -2175,6 +2198,8 @@ class FaustTorch: ...@@ -2175,6 +2198,8 @@ class FaustTorch:
else: else:
b_cost = b.size()[1] b_cost = b.size()[1]
return a_cost*b_cost return a_cost*b_cost
factors = self.factors.copy() + [ op ] factors = self.factors.copy() + [ op ]
costs = [cost(factors[i], factors[i+1]) for i in range(len(factors)-1) costs = [cost(factors[i], factors[i+1]) for i in range(len(factors)-1)
] ]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment