Mentions légales du service

Skip to content
Snippets Groups Projects
Commit a4b74b3c authored by CARRIVAIN Pascal's avatar CARRIVAIN Pascal
Browse files

add bit_rev_perm argument to fdb function

parent 9347725a
Branches
Tags
1 merge request!1review of fdb function
Pipeline #998179 failed
......@@ -1914,8 +1914,9 @@ def butterfly(M, type="bbtree", perm=None, diag_opt=False, mul_perm=None):
def fdb(matrix: np.ndarray, n_factors: int=4,
rank: int=2, orthonormalize: bool=True,
rank: int=1, orthonormalize: bool=True,
hierarchical_order: str='left-to-right',
bit_rev_perm: bool=False,
backend: str='numpy'):
"""Return a FAuST object corresponding to the factorization of ``matrix``.
Number of rows and number of columns of the matrix must be a power of two.
......@@ -1926,19 +1927,26 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
n_factors: ``int``, optional
Number of factors (4 is default).
rank: ``int``, optional
Rank of sub-blocks (2 is default)
Rank of sub-blocks (1 is default),
used by the underlying SVD.
orthonormalize: ``bool``, optional
True (default)
hierarchical_order: ``str``, optional
- 'left-to-right' (default)
- 'balanced'
bit_rev_perm: ``bool``, optional
Use bit reversal permutations matrix (default is ``True``).
It is useful when you would like to factorize DFT matrix.
With no bit-reversal permutations you would have to
tune the value of the rank as a function of the matrix size.
backend: ``str``, optional
Use numpy (default) or pytorch to compute
SVD and QR decompositions.
Returns:
A FAuST object that is the factorization of the ``matrix``.
A FAuST object that corresponds to the
factorization of the ``matrix``.
Raises:
Exception:
......@@ -1985,7 +1993,13 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
if type(matrix).__name__ == 'Torch' and backend != 'pytorch':
raise Exception("Because ``backend='pytorch'`` matrix" +
" must be a PyTorch tensor.")
matrix = matrix.reshape(1, 1, matrix_size[0], matrix_size[1])
if bit_rev_perm:
P = bitrev_perm(matrix.shape[0])
# matrix = (P @ matrix).reshape(1, 1, matrix_size[0], matrix_size[1])
matrix = (matrix @ P.T).reshape(1, 1, matrix_size[0], matrix_size[1])
else:
matrix = matrix.reshape(1, 1, matrix_size[0], matrix_size[1])
# Set architecture for butterfly factorization.
test = GB_operators.DebflyGen(nrows, ncols, rank)
......@@ -2017,7 +2031,10 @@ def fdb(matrix: np.ndarray, n_factors: int=4,
tmp if (np.count_nonzero(tmp) / (shape[0] * shape[1])) > 0.1
else csr_matrix(tmp)
)
if bit_rev_perm:
tmp = factor_list[0].dtype
# factor_list.insert(0, (P.astype(tmp)).T)
factor_list.append(P.astype(tmp))
return Faust(factor_list)
# return Faust([
# csr_matrix(GB_operators.twiddle_to_dense(f.factor,
# backend)) for f in factor_list])
import numpy as np
import os
from pyfaust import fact
import warnings
try:
import torch
found_pytorch = True
except ImportError:
found_pytorch = False
warn("Did not find PyTorch, therefore use NumPy/SciPy.")
warnings.warn("Did not find PyTorch, therefore use NumPy/SciPy.")
import scipy as sp
import unittest
......@@ -73,32 +74,12 @@ class TestFact(unittest.TestCase):
self.assertTrue(error < 1e-12)
# Number of data points
N = 2 ** np.random.randint(1, high=7 + 1)
N = 2 ** np.random.randint(1, high=8 + 1)
M = N + 1
while M > N:
M = 2 ** np.random.randint(1, high=7 + 1)
M = 2 ** np.random.randint(1, high=8 + 1)
M = N
if min(M, N) == 2:
rank = 1
elif min(M, N) == 4:
rank = 2
elif min(M, N) == 8:
rank = 2
elif min(M, N) == 16:
rank = 4
elif min(M, N) == 32:
rank = 4
elif min(M, N) == 64:
rank = 8
elif min(M, N) == 128:
rank = 8
elif min(M, N) == 256:
rank = 16
elif min(M, N) == 512:
rank = 16
else:
print('here', M, N, min(M, N))
continue
rank = 2 ** (int(np.log2(min(M, N))) // 2)
x = np.exp(-2.0j * np.pi * np.arange(N) / N)
V = np.vander(x, N=None, increasing=True)
print("vandermonde {0:d}: shape=({1:d}, {2:d}),".format(i, M, N) +
......@@ -106,16 +87,33 @@ class TestFact(unittest.TestCase):
" number of factors={0:d}".format(n_factors))
F = fact.fdb(V, n_factors=n_factors, rank=rank)
print(F)
ncols = F.factors(F.numfactors() - 1).shape[1]
nf = F.numfactors()
ncols = F.factors(nf - 1).shape[1]
approx = np.eye(ncols, M=N, k=0)
for f in range(F.numfactors()):
approx = F.factors(F.numfactors() - 1 - f) @ approx
for f in range(nf):
approx = F.factors(nf - 1 - f) @ approx
# Because of bit-reversal permutation we do not need
# to tune the rank as a function of matrix size.
F_bitrev = fact.fdb(V, n_factors=n_factors,
bit_rev_perm=True, rank=1)
print(F_bitrev)
nf = F_bitrev.numfactors()
ncols = F_bitrev.factors(nf - 1).shape[1]
approx_bitrev = np.eye(ncols, M=N, k=0)
for f in range(nf):
approx_bitrev = F_bitrev.factors(nf - 1 - f) @ approx_bitrev
# np.set_printoptions(linewidth=200)
# print(V)
# import pyfaust as pyf
# print(pyf.dft(M, normed=False).toarray())
# print(pyf.dft(M, normed=True).toarray())
# print(approx)
error = np.linalg.norm(V - approx) / np.linalg.norm(V)
# print(np.round(V, 3))
# print(np.round(approx, 3))
print('error={0:e}'.format(error))
self.assertTrue(error < 1e-12)
error = np.linalg.norm(V - approx_bitrev) / np.linalg.norm(V)
print('error={0:e}'.format(error))
self.assertTrue(error < 1e-12)
# Cross-validation no PyTorch version versus PyTorch version.
for i in range(n_repeats):
......@@ -133,7 +131,8 @@ class TestFact(unittest.TestCase):
" number of factors={0:d}".format(n_factors))
# Use PyTorch
matrix0 = torch.randn(M, N)
F0 = fact.fdb(matrix0, n_factors=n_factors, rank=rank, backend='pytorch')
F0 = fact.fdb(matrix0, n_factors=n_factors,
rank=rank, backend='pytorch')
print(F0)
ncols = F0.factors(F0.numfactors() - 1).shape[1]
approx0 = np.eye(ncols, M=N, k=0)
......@@ -141,7 +140,8 @@ class TestFact(unittest.TestCase):
approx0 = F0.factors(F0.numfactors() - 1 - f) @ approx0
# Use NumPy
matrix1 = matrix0.numpy()
F1 = fact.fdb(matrix1, n_factors=n_factors, rank=rank, backend='numpy')
F1 = fact.fdb(matrix1, n_factors=n_factors,
rank=rank, backend='numpy')
print(F1)
ncols = F1.factors(F1.numfactors() - 1).shape[1]
approx1 = np.eye(ncols, M=N, k=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment