Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 62ee20f1 authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

merge fast-deformable-butterfly review

parents 9d905a69 83a62380
No related branches found
No related tags found
1 merge request!1review of fdb function
Pipeline #1000739 failed
Subproject commit 163439ada9c3c0a74f9aa9057ae8e9e14dd21ba4
Subproject commit 47ad960169c5920c2b34da272171012f4a7850b8
......@@ -40,7 +40,8 @@ import warnings
from pyfaust.tools import _sanitize_dtype
from pyfaust.fdb import GB_factorization
from pyfaust.fdb import GB_operators
from pyfaust.fdb.GB_operators import twiddle_to_dense
from pyfaust.fdb.GB_param_generate import DebflyGen
# experimental block start
......@@ -1918,7 +1919,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
hierarchical_order: str='left-to-right',
bit_rev_perm: bool=False,
backend: str='numpy'):
"""Return a FAuST object corresponding to the factorization of ``matrix``.
"""Return a Faust object corresponding to the factorization of ``matrix``.
Args:
matrix: ``np.ndarray``
......@@ -1934,6 +1935,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
- 'left-to-right' (default)
- 'balanced'
bit_rev_perm: ``bool``, optional
Use bit reversal permutations matrix (default is ``False``).
It is useful when you would like to factorize DFT matrix.
......@@ -1944,8 +1946,8 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
SVD and QR decompositions.
Returns:
A FAuST object that corresponds to the
factorization of the ``matrix``.
A Faust object that corresponds to the
factorization of ``matrix``.
Raises:
NotImplementedError
......@@ -1995,7 +1997,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
matrix = matrix.reshape(1, 1, matrix_size[0], matrix_size[1])
# Set architecture for butterfly factorization.
test = GB_operators.DebflyGen(nrows, ncols, rank)
test = DebflyGen(nrows, ncols, rank)
m, min_param = test.smallest_monotone_debfly_chain(
n_factors, format="abcdpq"
)
......@@ -2015,7 +2017,7 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
backend=backend)
for i, f in enumerate(factor_list):
# f.factor is either dense array (NumPy) or csr SciPy array.
tmp = GB_operators.twiddle_to_dense(f.factor, backend='numpy')
tmp = twiddle_to_dense(f.factor, backend='numpy')
shape = tmp.shape
# Build a list of factors.
# If more than 10% of non-zero elements use dense format,
......
......@@ -14,9 +14,9 @@ try:
except ImportError:
warn("Did not find einops, therefore use NumPy.")
found_einops = False
from pyfaust.fdb.utils import *
from pyfaust.fdb.GB_operators import *
from pyfaust.fdb.utils import partial_prod_deformable_butterfly_params, Factor
import numpy as np
import os
if "GB_DISABLE_EINOPS" in dict(os.environ).keys():
......
......@@ -14,9 +14,10 @@ try:
except ImportError:
found_einops = False
warn("Did not find einops, therefore use NumPy.")
from pyfaust.fdb.GB_param_generate import *
from pyfaust.fdb.utils import param_mul_param
# from pyfaust.fdb.GB_param_generate import *
# from pyfaust.fdb.utils import param_mul_param
import os
import numpy as np
if "GB_DISABLE_EINOPS" in dict(os.environ).keys():
......@@ -71,7 +72,7 @@ def twiddle_mul_twiddle(l_twiddle, r_twiddle, l_param,
if backend == 'pytorch' and found_pytorch:
result = torch.matmul(l_twiddle.float(), r_twiddle.float())
else:
result = l_twiddle.astype(np.float_) @ r_twiddle.astype(np.float_)
result = l_twiddle.astype(np.float64) @ r_twiddle.astype(np.float64)
if found_einops:
result = rearrange(
result, "(a c1) (b2 d) b1 c2 -> a d (b1 b2) (c1 c2)", c1=c1, b2=b2
......
import numpy as np
from warnings import warn
try:
from sympy import primerange
found_sympy = True
except ImportError:
found_sympy = False
warn(
"Did not find SymPy, therefore use"
+ "home-made ``primerange`` function."
)
try:
import torch
found_pytorch = True
except ImportError:
found_pytorch = False
warn("Did not find PyTorch, therefore use NumPy/SciPy.")
# from warnings import warn
# try:
# from sympy import primerange
# #
# found_sympy = True
# except ImportError:
# found_sympy = False
# warn("Did not find SymPy, therefore use" +
# " home-made ``primerange`` function.")
MAX = 1e18
......@@ -75,11 +65,12 @@ def check_compatibility(b, c, type):
elif type == "shrinking":
return b >= c
else:
raise Exception("type must be either 'square'," +
" 'expanding' or 'shrinking'.")
raise Exception(
"type must be either 'square'," + " 'expanding' or 'shrinking'."
)
def format_conversion(m, n, chainbc, weight, format: str="abcd"):
def format_conversion(m, n, chainbc, weight, format: str = "abcd"):
"""Return a sequence of deformable butterfly factors
using the infomation of b and c.
......@@ -117,8 +108,8 @@ def format_conversion(m, n, chainbc, weight, format: str="abcd"):
elif format == "abcdpq":
result.append((a, b, c, d, weight[i], weight[i + 1]))
else:
raise Exception("format must be either 'abcd'," +
" 'pqrst' or 'abcdpq'.")
raise Exception("format must be either 'abcd',"
+ " 'pqrst' or 'abcdpq'.")
a = a * c
return result
......@@ -233,20 +224,19 @@ class DebflyGen:
# self.m, self.n, chainbc, weight, format=format
# )
@staticmethod
def enumeration_inner_chain(m, n_factors):
if m == 1:
return [[1] * n_factors]
f_divisors, f_powers = list(factorize(m).items())[0]
results = []
for f1 in enumerate_Euler_sum(f_powers, n_factors):
for f2 in DebflyGen.enumeration_inner_chain(
m // (f_divisors**f_powers), n_factors
):
results.append(
[(f_divisors**a) * b for (a, b) in zip(f1, f2)]
)
return results
# @staticmethod
# def enumeration_inner_chain(m, n_factors):
# if m == 1:
# return [[1] * n_factors]
# f_divisors, f_powers = list(factorize(m).items())[0]
# results = []
# for f1 in enumerate_Euler_sum(f_powers, n_factors):
# for f2 in DebflyGen.enumeration_inner_chain(
# m // (f_divisors**f_powers), n_factors
# ):
# results.append([(f_divisors**a) * b
# for (a, b) in zip(f1, f2)])
# return results
# def enumeration_debfly_chain(self, n_factors, format="abcd"):
# results = []
......@@ -266,7 +256,7 @@ class DebflyGen:
# )
# return results
def smallest_monotone_debfly_chain(self, n_factors, format: str="abcd"):
def smallest_monotone_debfly_chain(self, n_factors, format: str = "abcd"):
"""Return a deformable butterfly chain whose product is of
size m x n has n_factors factors.
......@@ -319,16 +309,15 @@ class DebflyGen:
continue
if not check_compatibility(ii, jj, type):
continue
n_params_factor = (
i * jj * weight[k] * weight[k + 1]
)
n_params_fact = (i * jj * weight[k]
* weight[k + 1])
dp_tab = self.dp_table
if (
self.dp_table_temp[i, j]
> n_params_factor
+ jj * self.dp_table[i // ii, j // jj]
> n_params_fact + jj * dp_tab[i // ii, j // jj]
):
self.dp_table_temp[i, j] = (
n_params_factor
n_params_fact
+ jj * self.dp_table[i // ii, j // jj]
)
memorization[(i, j, k + 1)] = (ii, jj)
......
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial import distance
class Factor:
"""
......
......@@ -16,6 +16,14 @@ import unittest
class TestFact(unittest.TestCase):
"""Test case for the 'fact' module."""
def __init__(self, methodName='runTest', dev='cpu', dtype='double'):
super(TestFact, self).__init__(methodName)
self.dev = dev
if dtype == 'real': # backward compat
dtype = 'double'
self.dtype = dtype
# TODO: use different types later
def test_fdb(self):
"""Test of the function 'fdb'."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment