Mentions légales du service

Skip to content
Snippets Groups Projects
Commit a4b74b3c authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

add bit_rev_perm argument to fdb function

parent 9347725a
No related branches found
No related tags found
1 merge request!1review of fdb function
Pipeline #998179 failed
...@@ -1914,8 +1914,9 @@ def butterfly(M, type="bbtree", perm=None, diag_opt=False, mul_perm=None): ...@@ -1914,8 +1914,9 @@ def butterfly(M, type="bbtree", perm=None, diag_opt=False, mul_perm=None):
def fdb(matrix: np.ndarray, n_factors: int=4, def fdb(matrix: np.ndarray, n_factors: int=4,
rank: int=2, orthonormalize: bool=True, rank: int=1, orthonormalize: bool=True,
hierarchical_order: str='left-to-right', hierarchical_order: str='left-to-right',
bit_rev_perm: bool=False,
backend: str='numpy'): backend: str='numpy'):
"""Return a FAuST object corresponding to the factorization of ``matrix``. """Return a FAuST object corresponding to the factorization of ``matrix``.
Number of rows and number of columns of the matrix must be a power of two. Number of rows and number of columns of the matrix must be a power of two.
...@@ -1926,19 +1927,26 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -1926,19 +1927,26 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
n_factors: ``int``, optional n_factors: ``int``, optional
Number of factors (4 is default). Number of factors (4 is default).
rank: ``int``, optional rank: ``int``, optional
Rank of sub-blocks (2 is default) Rank of sub-blocks (1 is default),
used by the underlying SVD.
orthonormalize: ``bool``, optional orthonormalize: ``bool``, optional
True (default) True (default)
hierarchical_order: ``str``, optional hierarchical_order: ``str``, optional
- 'left-to-right' (default) - 'left-to-right' (default)
- 'balanced' - 'balanced'
bit_rev_perm: ``bool``, optional
Use bit reversal permutations matrix (default is ``True``).
It is useful when you would like to factorize DFT matrix.
With no bit-reversal permutations you would have to
tune the value of the rank as a function of the matrix size.
backend: ``str``, optional backend: ``str``, optional
Use numpy (default) or pytorch to compute Use numpy (default) or pytorch to compute
SVD and QR decompositions. SVD and QR decompositions.
Returns: Returns:
A FAuST object that is the factorization of the ``matrix``. A FAuST object that corresponds to the
factorization of the ``matrix``.
Raises: Raises:
Exception: Exception:
...@@ -1985,7 +1993,13 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -1985,7 +1993,13 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
if type(matrix).__name__ == 'Torch' and backend != 'pytorch': if type(matrix).__name__ == 'Torch' and backend != 'pytorch':
raise Exception("Because ``backend='pytorch'`` matrix" + raise Exception("Because ``backend='pytorch'`` matrix" +
" must be a PyTorch tensor.") " must be a PyTorch tensor.")
matrix = matrix.reshape(1, 1, matrix_size[0], matrix_size[1])
if bit_rev_perm:
P = bitrev_perm(matrix.shape[0])
# matrix = (P @ matrix).reshape(1, 1, matrix_size[0], matrix_size[1])
matrix = (matrix @ P.T).reshape(1, 1, matrix_size[0], matrix_size[1])
else:
matrix = matrix.reshape(1, 1, matrix_size[0], matrix_size[1])
# Set architecture for butterfly factorization. # Set architecture for butterfly factorization.
test = GB_operators.DebflyGen(nrows, ncols, rank) test = GB_operators.DebflyGen(nrows, ncols, rank)
...@@ -2017,7 +2031,10 @@ def fdb(matrix: np.ndarray, n_factors: int=4, ...@@ -2017,7 +2031,10 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
tmp if (np.count_nonzero(tmp) / (shape[0] * shape[1])) > 0.1 tmp if (np.count_nonzero(tmp) / (shape[0] * shape[1])) > 0.1
else csr_matrix(tmp) else csr_matrix(tmp)
) )
if bit_rev_perm:
tmp = factor_list[0].dtype
# factor_list.insert(0, (P.astype(tmp)).T)
factor_list.append(P.astype(tmp))
return Faust(factor_list) return Faust(factor_list)
# return Faust([
# csr_matrix(GB_operators.twiddle_to_dense(f.factor,
# backend)) for f in factor_list])
import numpy as np import numpy as np
import os import os
from pyfaust import fact from pyfaust import fact
import warnings
try: try:
import torch import torch
found_pytorch = True found_pytorch = True
except ImportError: except ImportError:
found_pytorch = False found_pytorch = False
warn("Did not find PyTorch, therefore use NumPy/SciPy.") warnings.warn("Did not find PyTorch, therefore use NumPy/SciPy.")
import scipy as sp import scipy as sp
import unittest import unittest
...@@ -73,32 +74,12 @@ class TestFact(unittest.TestCase): ...@@ -73,32 +74,12 @@ class TestFact(unittest.TestCase):
self.assertTrue(error < 1e-12) self.assertTrue(error < 1e-12)
# Number of data points # Number of data points
N = 2 ** np.random.randint(1, high=7 + 1) N = 2 ** np.random.randint(1, high=8 + 1)
M = N + 1 M = N + 1
while M > N: while M > N:
M = 2 ** np.random.randint(1, high=7 + 1) M = 2 ** np.random.randint(1, high=8 + 1)
M = N M = N
if min(M, N) == 2: rank = 2 ** (int(np.log2(min(M, N))) // 2)
rank = 1
elif min(M, N) == 4:
rank = 2
elif min(M, N) == 8:
rank = 2
elif min(M, N) == 16:
rank = 4
elif min(M, N) == 32:
rank = 4
elif min(M, N) == 64:
rank = 8
elif min(M, N) == 128:
rank = 8
elif min(M, N) == 256:
rank = 16
elif min(M, N) == 512:
rank = 16
else:
print('here', M, N, min(M, N))
continue
x = np.exp(-2.0j * np.pi * np.arange(N) / N) x = np.exp(-2.0j * np.pi * np.arange(N) / N)
V = np.vander(x, N=None, increasing=True) V = np.vander(x, N=None, increasing=True)
print("vandermonde {0:d}: shape=({1:d}, {2:d}),".format(i, M, N) + print("vandermonde {0:d}: shape=({1:d}, {2:d}),".format(i, M, N) +
...@@ -106,16 +87,33 @@ class TestFact(unittest.TestCase): ...@@ -106,16 +87,33 @@ class TestFact(unittest.TestCase):
" number of factors={0:d}".format(n_factors)) " number of factors={0:d}".format(n_factors))
F = fact.fdb(V, n_factors=n_factors, rank=rank) F = fact.fdb(V, n_factors=n_factors, rank=rank)
print(F) print(F)
ncols = F.factors(F.numfactors() - 1).shape[1] nf = F.numfactors()
ncols = F.factors(nf - 1).shape[1]
approx = np.eye(ncols, M=N, k=0) approx = np.eye(ncols, M=N, k=0)
for f in range(F.numfactors()): for f in range(nf):
approx = F.factors(F.numfactors() - 1 - f) @ approx approx = F.factors(nf - 1 - f) @ approx
# Because of bit-reversal permutation we do not need
# to tune the rank as a function of matrix size.
F_bitrev = fact.fdb(V, n_factors=n_factors,
bit_rev_perm=True, rank=1)
print(F_bitrev)
nf = F_bitrev.numfactors()
ncols = F_bitrev.factors(nf - 1).shape[1]
approx_bitrev = np.eye(ncols, M=N, k=0)
for f in range(nf):
approx_bitrev = F_bitrev.factors(nf - 1 - f) @ approx_bitrev
# np.set_printoptions(linewidth=200)
# print(V)
# import pyfaust as pyf
# print(pyf.dft(M, normed=False).toarray())
# print(pyf.dft(M, normed=True).toarray())
# print(approx)
error = np.linalg.norm(V - approx) / np.linalg.norm(V) error = np.linalg.norm(V - approx) / np.linalg.norm(V)
# print(np.round(V, 3))
# print(np.round(approx, 3))
print('error={0:e}'.format(error)) print('error={0:e}'.format(error))
self.assertTrue(error < 1e-12) self.assertTrue(error < 1e-12)
error = np.linalg.norm(V - approx_bitrev) / np.linalg.norm(V)
print('error={0:e}'.format(error))
self.assertTrue(error < 1e-12)
# Cross-validation no PyTorch version versus PyTorch version. # Cross-validation no PyTorch version versus PyTorch version.
for i in range(n_repeats): for i in range(n_repeats):
...@@ -133,7 +131,8 @@ class TestFact(unittest.TestCase): ...@@ -133,7 +131,8 @@ class TestFact(unittest.TestCase):
" number of factors={0:d}".format(n_factors)) " number of factors={0:d}".format(n_factors))
# Use PyTorch # Use PyTorch
matrix0 = torch.randn(M, N) matrix0 = torch.randn(M, N)
F0 = fact.fdb(matrix0, n_factors=n_factors, rank=rank, backend='pytorch') F0 = fact.fdb(matrix0, n_factors=n_factors,
rank=rank, backend='pytorch')
print(F0) print(F0)
ncols = F0.factors(F0.numfactors() - 1).shape[1] ncols = F0.factors(F0.numfactors() - 1).shape[1]
approx0 = np.eye(ncols, M=N, k=0) approx0 = np.eye(ncols, M=N, k=0)
...@@ -141,7 +140,8 @@ class TestFact(unittest.TestCase): ...@@ -141,7 +140,8 @@ class TestFact(unittest.TestCase):
approx0 = F0.factors(F0.numfactors() - 1 - f) @ approx0 approx0 = F0.factors(F0.numfactors() - 1 - f) @ approx0
# Use NumPy # Use NumPy
matrix1 = matrix0.numpy() matrix1 = matrix0.numpy()
F1 = fact.fdb(matrix1, n_factors=n_factors, rank=rank, backend='numpy') F1 = fact.fdb(matrix1, n_factors=n_factors,
rank=rank, backend='numpy')
print(F1) print(F1)
ncols = F1.factors(F1.numfactors() - 1).shape[1] ncols = F1.factors(F1.numfactors() - 1).shape[1]
approx1 = np.eye(ncols, M=N, k=0) approx1 = np.eye(ncols, M=N, k=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment