Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 52fd0b9c authored by hhakim's avatar hhakim
Browse files

Allow to try multiple permutations in butterfly and keep the best...

Allow to try multiple permutations in butterfly and keep the best corresponding approximation Faust + documentation and examples.
parent 567c6b8f
Branches
Tags
No related merge requests found
......@@ -28,8 +28,7 @@
import numpy as np, scipy
from scipy.io import loadmat
from scipy.sparse import csr_matrix, csc_matrix
import _FaustCorePy
from scipy.sparse import csr_matrix, csc_matrix, eye as seye, kron as skron
import pyfaust
import pyfaust.factparams
from pyfaust import Faust
......@@ -1338,20 +1337,42 @@ def butterfly(M, type="bbtree", perm=None):
factorization the most left factor (resp. the most right factor) is split in two.
If 'bbtree' is used then the matrix is factorized according to a Balanced
Binary Tree (which is faster as it allows parallelization).
perm: permutation column indices of the permutation P which is such that
the returned Faust F is the approximation of M@P and F@P.T the approximation of M.
perm: four kind of values are possible for this argument (Note that this argument is made only for the bbtree type of
factorization).
1. perm is the list of column indices of the permutation P which is such that
the returned Faust F is the approximation of M@P and F@P.T the
approximation of M.
2. perm is a list of list of permutation column indices as defined
in 1. In that case, all permutations passed to the function are
used as explained in 1, each one producing a Faust, the best one
(that is the best approximation of M) is kept and returned by butterfly.
3. perm is 'default_8', this is a particular case of 2. Eight
default permutations are used. For the definition of those
permutations please refer to [1].
4. By default this argument is None, no permutation is used.
Returns:
The Faust which is an approximattion of M according to a butterfly support.
Example:
>>> import numpy as np
>>> from random import randint
>>> from pyfaust import Faust, wht, dft
>>> from pyfaust.fact import butterfly
>>> H = wht(32).toarray() # it works with dft too!
>>> H = wht(8).toarray() # it works with dft too!
>>> F = butterfly(H, type='bbtree')
>>> (F-H).norm()/Faust(H).norm()
1.0272844187006565e-15
# use butterfly with a permutation factor defined by J
>>> J = np.arange(7, -1, -1)
>>> F = butterfly(H, type='bbtree', perm=J)
# use butterfly with successive permutation factors J1 and J2
# and keep the best approximation
>>> J1 = J
>>> from itertools import permutations
>>> permutations = list(permutations(J))
>>> J2 = list(permutations[randint(0, len(permutations)-1)])
>>> F = butterfly(H, type='bbtree', perm=[J1, J2])
Reference:
[1] Quoc-Tung Le, Léon Zheng, Elisa Riccietti, Rémi Gribonval. Fast
......@@ -1360,7 +1381,82 @@ def butterfly(M, type="bbtree", perm=None):
May 2022, Singapore, Singapore. (<a href="https://hal.inria.fr/hal-03438881">hal-03438881</a>)
"""
is_real = np.empty((1,))
if perm is not None and type != 'bbtree':
raise ValueError('perm argument is made only for type bbtree')
M = _check_fact_mat('butterfly()', M, is_real)
if isinstance(perm, str) and perm == 'default_8':
# the three modified functions below were originally extracted from the 3 clause-BSD code hosted here: https://github.com/leonzheng2/butterfly
# please look the header license here https://github.com/leonzheng2/butterfly/blob/main/src/utils.py
def perm_type(i, type):
"""
Type 0 is c in paper. Type 1 is b in paper. Type 2 is a in paper.
:param i:
:param type:
:return:
"""
size = 2 ** i
if type == 0:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2),size - 1 - np.arange(size//2)))
elif type == 1:
row_inds = np.hstack((size // 2 - 1 - np.arange(size//2), size // 2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
else:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2) * 2, np.arange(size//2) * 2 + 1))
result = csr_matrix((np.ones(row_inds.size), (row_inds, col_inds)))
return result
def shared_logits_permutation(num_factors, choices):
"""
:param num_factors:
:param choices: array of three bool
:return:
"""
permutations = []
for i in range(2, num_factors + 1):
block = seye(2 ** i)
if choices[0]:
block = block @ perm_type(i, 0)
if choices[1]:
block = block @ perm_type(i, 1)
if choices[2]:
block = block @ perm_type(i, 2)
perm = skron(seye(2 ** (num_factors - i)), block)
permutations.append(perm)
return permutations
def get_permutation_matrix(num_factors, perm_name):
"""
:param num_factors:
:param perm_name: str, 000, 001, ..., 111
:return:
"""
if perm_name.isnumeric():
choices = [int(char) for char in perm_name]
p_list = shared_logits_permutation(num_factors, choices)
p = csr_matrix(Faust(p_list).toarray()) # TODO: keep csr along the whole product
else:
raise TypeError("perm_name must be numeric")
return p
permutations = [get_permutation_matrix(int(np.log2(M.shape[0])),
perm_name).indices \
for perm_name in ["000", "001", "010", "011", "100", "101", "110", "111"]]
return butterfly(M, type, permutations)
elif isinstance(perm, (list, tuple)) and isinstance(perm[0], (list, tuple,
np.ndarray)):
# loop on each perm and keep the best approximation
best_err = np.inf
best_F = None
for p in perm:
F = butterfly(M, type, p)
# compute error
error = (F-M).norm()/Faust(M).norm()
if error < best_err:
best_err = error
best_F = F
return best_F
args = (M, type, perm)
if is_real:
is_float = M.dtype == 'float32'
......
......@@ -583,7 +583,6 @@ cdef class FaustAlgoGen@TYPE_NAME@:
@staticmethod
def butterfly_hierarchical(M, dir, perm=None):
print("pyfaust.butterfly pyx")
cdef unsigned int M_num_rows=M.shape[0]
cdef unsigned int M_num_cols=M.shape[1]
cdef int[:] perm_view
......@@ -614,7 +613,7 @@ cdef class FaustAlgoGen@TYPE_NAME@:
if perm.ndim != 1:
raise ValueError('perm must be a vector')
else: # perm is a list
perm = np.ndarray(perm)
perm = np.array(perm)
if (perm >= M.shape[1]).any() or (perm < 0).any():
raise ValueError('perm contains an invalid index')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment