Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 62ee20f1 authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

merge fast-deformable-butterfly review

parents 9d905a69 83a62380
No related branches found
No related tags found
1 merge request!1review of fdb function
Pipeline #1000739 failed
Subproject commit 163439ada9c3c0a74f9aa9057ae8e9e14dd21ba4 Subproject commit 47ad960169c5920c2b34da272171012f4a7850b8
...@@ -40,7 +40,8 @@ import warnings ...@@ -40,7 +40,8 @@ import warnings
from pyfaust.tools import _sanitize_dtype from pyfaust.tools import _sanitize_dtype
from pyfaust.fdb import GB_factorization from pyfaust.fdb import GB_factorization
from pyfaust.fdb import GB_operators from pyfaust.fdb.GB_operators import twiddle_to_dense
from pyfaust.fdb.GB_param_generate import DebflyGen
# experimental block start # experimental block start
...@@ -1918,7 +1919,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -1918,7 +1919,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
hierarchical_order: str='left-to-right', hierarchical_order: str='left-to-right',
bit_rev_perm: bool=False, bit_rev_perm: bool=False,
backend: str='numpy'): backend: str='numpy'):
"""Return a FAuST object corresponding to the factorization of ``matrix``. """Return a Faust object corresponding to the factorization of ``matrix``.
Args: Args:
matrix: ``np.ndarray`` matrix: ``np.ndarray``
...@@ -1934,6 +1935,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -1934,6 +1935,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
- 'left-to-right' (default) - 'left-to-right' (default)
- 'balanced' - 'balanced'
bit_rev_perm: ``bool``, optional bit_rev_perm: ``bool``, optional
Use bit reversal permutations matrix (default is ``False``). Use bit reversal permutations matrix (default is ``False``).
It is useful when you would like to factorize DFT matrix. It is useful when you would like to factorize DFT matrix.
...@@ -1944,8 +1946,8 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -1944,8 +1946,8 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
SVD and QR decompositions. SVD and QR decompositions.
Returns: Returns:
A FAuST object that corresponds to the A Faust object that corresponds to the
factorization of the ``matrix``. factorization of ``matrix``.
Raises: Raises:
NotImplementedError NotImplementedError
...@@ -1995,7 +1997,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -1995,7 +1997,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
matrix = matrix.reshape(1, 1, matrix_size[0], matrix_size[1]) matrix = matrix.reshape(1, 1, matrix_size[0], matrix_size[1])
# Set architecture for butterfly factorization. # Set architecture for butterfly factorization.
test = GB_operators.DebflyGen(nrows, ncols, rank) test = DebflyGen(nrows, ncols, rank)
m, min_param = test.smallest_monotone_debfly_chain( m, min_param = test.smallest_monotone_debfly_chain(
n_factors, format="abcdpq" n_factors, format="abcdpq"
) )
...@@ -2015,7 +2017,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -2015,7 +2017,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
backend=backend) backend=backend)
for i, f in enumerate(factor_list): for i, f in enumerate(factor_list):
# f.factor is either dense array (NumPy) or csr SciPy array. # f.factor is either dense array (NumPy) or csr SciPy array.
tmp = GB_operators.twiddle_to_dense(f.factor, backend='numpy') tmp = twiddle_to_dense(f.factor, backend='numpy')
shape = tmp.shape shape = tmp.shape
# Build a list of factors. # Build a list of factors.
# If more than 10% of non-zero elements use dense format, # If more than 10% of non-zero elements use dense format,
......
...@@ -14,9 +14,9 @@ try: ...@@ -14,9 +14,9 @@ try:
except ImportError: except ImportError:
warn("Did not find einops, therefore use NumPy.") warn("Did not find einops, therefore use NumPy.")
found_einops = False found_einops = False
from pyfaust.fdb.utils import * from pyfaust.fdb.utils import partial_prod_deformable_butterfly_params, Factor
from pyfaust.fdb.GB_operators import *
import numpy as np import numpy as np
import os
if "GB_DISABLE_EINOPS" in dict(os.environ).keys(): if "GB_DISABLE_EINOPS" in dict(os.environ).keys():
......
...@@ -14,9 +14,10 @@ try: ...@@ -14,9 +14,10 @@ try:
except ImportError: except ImportError:
found_einops = False found_einops = False
warn("Did not find einops, therefore use NumPy.") warn("Did not find einops, therefore use NumPy.")
from pyfaust.fdb.GB_param_generate import * # from pyfaust.fdb.GB_param_generate import *
from pyfaust.fdb.utils import param_mul_param # from pyfaust.fdb.utils import param_mul_param
import os import os
import numpy as np
if "GB_DISABLE_EINOPS" in dict(os.environ).keys(): if "GB_DISABLE_EINOPS" in dict(os.environ).keys():
...@@ -71,7 +72,7 @@ def twiddle_mul_twiddle(l_twiddle, r_twiddle, l_param, ...@@ -71,7 +72,7 @@ def twiddle_mul_twiddle(l_twiddle, r_twiddle, l_param,
if backend == 'pytorch' and found_pytorch: if backend == 'pytorch' and found_pytorch:
result = torch.matmul(l_twiddle.float(), r_twiddle.float()) result = torch.matmul(l_twiddle.float(), r_twiddle.float())
else: else:
result = l_twiddle.astype(np.float_) @ r_twiddle.astype(np.float_) result = l_twiddle.astype(np.float64) @ r_twiddle.astype(np.float64)
if found_einops: if found_einops:
result = rearrange( result = rearrange(
result, "(a c1) (b2 d) b1 c2 -> a d (b1 b2) (c1 c2)", c1=c1, b2=b2 result, "(a c1) (b2 d) b1 c2 -> a d (b1 b2) (c1 c2)", c1=c1, b2=b2
......
import numpy as np import numpy as np
from warnings import warn # from warnings import warn
try:
from sympy import primerange
found_sympy = True
except ImportError:
found_sympy = False
warn(
"Did not find SymPy, therefore use"
+ "home-made ``primerange`` function."
)
try:
import torch
found_pytorch = True
except ImportError:
found_pytorch = False
warn("Did not find PyTorch, therefore use NumPy/SciPy.")
# try:
# from sympy import primerange
# #
# found_sympy = True
# except ImportError:
# found_sympy = False
# warn("Did not find SymPy, therefore use" +
# " home-made ``primerange`` function.")
MAX = 1e18 MAX = 1e18
...@@ -75,11 +65,12 @@ def check_compatibility(b, c, type): ...@@ -75,11 +65,12 @@ def check_compatibility(b, c, type):
elif type == "shrinking": elif type == "shrinking":
return b >= c return b >= c
else: else:
raise Exception("type must be either 'square'," + raise Exception(
" 'expanding' or 'shrinking'.") "type must be either 'square'," + " 'expanding' or 'shrinking'."
)
def format_conversion(m, n, chainbc, weight, format: str="abcd"): def format_conversion(m, n, chainbc, weight, format: str = "abcd"):
"""Return a sequence of deformable butterfly factors """Return a sequence of deformable butterfly factors
using the infomation of b and c. using the infomation of b and c.
...@@ -117,8 +108,8 @@ def format_conversion(m, n, chainbc, weight, format: str="abcd"): ...@@ -117,8 +108,8 @@ def format_conversion(m, n, chainbc, weight, format: str="abcd"):
elif format == "abcdpq": elif format == "abcdpq":
result.append((a, b, c, d, weight[i], weight[i + 1])) result.append((a, b, c, d, weight[i], weight[i + 1]))
else: else:
raise Exception("format must be either 'abcd'," + raise Exception("format must be either 'abcd',"
" 'pqrst' or 'abcdpq'.") + " 'pqrst' or 'abcdpq'.")
a = a * c a = a * c
return result return result
...@@ -233,20 +224,19 @@ class DebflyGen: ...@@ -233,20 +224,19 @@ class DebflyGen:
# self.m, self.n, chainbc, weight, format=format # self.m, self.n, chainbc, weight, format=format
# ) # )
@staticmethod # @staticmethod
def enumeration_inner_chain(m, n_factors): # def enumeration_inner_chain(m, n_factors):
if m == 1: # if m == 1:
return [[1] * n_factors] # return [[1] * n_factors]
f_divisors, f_powers = list(factorize(m).items())[0] # f_divisors, f_powers = list(factorize(m).items())[0]
results = [] # results = []
for f1 in enumerate_Euler_sum(f_powers, n_factors): # for f1 in enumerate_Euler_sum(f_powers, n_factors):
for f2 in DebflyGen.enumeration_inner_chain( # for f2 in DebflyGen.enumeration_inner_chain(
m // (f_divisors**f_powers), n_factors # m // (f_divisors**f_powers), n_factors
): # ):
results.append( # results.append([(f_divisors**a) * b
[(f_divisors**a) * b for (a, b) in zip(f1, f2)] # for (a, b) in zip(f1, f2)])
) # return results
return results
# def enumeration_debfly_chain(self, n_factors, format="abcd"): # def enumeration_debfly_chain(self, n_factors, format="abcd"):
# results = [] # results = []
...@@ -266,7 +256,7 @@ class DebflyGen: ...@@ -266,7 +256,7 @@ class DebflyGen:
# ) # )
# return results # return results
def smallest_monotone_debfly_chain(self, n_factors, format: str="abcd"): def smallest_monotone_debfly_chain(self, n_factors, format: str = "abcd"):
"""Return a deformable butterfly chain whose product is of """Return a deformable butterfly chain whose product is of
size m x n has n_factors factors. size m x n has n_factors factors.
...@@ -319,16 +309,15 @@ class DebflyGen: ...@@ -319,16 +309,15 @@ class DebflyGen:
continue continue
if not check_compatibility(ii, jj, type): if not check_compatibility(ii, jj, type):
continue continue
n_params_factor = ( n_params_fact = (i * jj * weight[k]
i * jj * weight[k] * weight[k + 1] * weight[k + 1])
) dp_tab = self.dp_table
if ( if (
self.dp_table_temp[i, j] self.dp_table_temp[i, j]
> n_params_factor > n_params_fact + jj * dp_tab[i // ii, j // jj]
+ jj * self.dp_table[i // ii, j // jj]
): ):
self.dp_table_temp[i, j] = ( self.dp_table_temp[i, j] = (
n_params_factor n_params_fact
+ jj * self.dp_table[i // ii, j // jj] + jj * self.dp_table[i // ii, j // jj]
) )
memorization[(i, j, k + 1)] = (ii, jj) memorization[(i, j, k + 1)] = (ii, jj)
......
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial import distance
class Factor: class Factor:
""" """
......
...@@ -16,6 +16,14 @@ import unittest ...@@ -16,6 +16,14 @@ import unittest
class TestFact(unittest.TestCase): class TestFact(unittest.TestCase):
"""Test case for the 'fact' module.""" """Test case for the 'fact' module."""
def __init__(self, methodName='runTest', dev='cpu', dtype='double'):
super(TestFact, self).__init__(methodName)
self.dev = dev
if dtype == 'real': # backward compat
dtype = 'double'
self.dtype = dtype
# TODO: use different types later
def test_fdb(self): def test_fdb(self):
"""Test of the function 'fdb'.""" """Test of the function 'fdb'."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment