Mentions légales du service

Skip to content
Snippets Groups Projects
Commit fac21fa7 authored by hhakim's avatar hhakim
Browse files

Update pyfaust/matfaust.fact.butterfly: add 'bitrev' option for perm argument and complete the doc.

parent 85c82b31
Branches
Tags 3.31.0
No related merge requests found
Pipeline #834098 skipped
...@@ -7,12 +7,13 @@ ...@@ -7,12 +7,13 @@
%> Binary Tree (which is faster as it allows parallelization). %> Binary Tree (which is faster as it allows parallelization).
%> %>
%> %>
%> @param 'perm', value four kind of values are possible for this argument (Note that this argument is made only for the bbtree type of factorization). %> @param 'perm', value five kinds of values are possible for this argument.
%> %>
%> 1. perm is an array of column indices of the permutation matrix P which is such that the returned Faust is F = G * P.' where G is the Faust butterfly approximation of M*P. %> 1. perm is an array of column indices of the permutation matrix P which is such that the returned Faust is F = G * P where G is the Faust butterfly approximation of M*P.'. If the array of indices is not a valid permutation the behaviour is undefined (however an invalid size or an out of bound index raise an exception).
%> 2. perm is a cell array of arrays of permutation column indices as defined in 1. In that case, all permutations passed to the function are used as explained in 1, each one producing a Faust, the best one (that is the best approximation of M) is kept and returned by butterfly. %> 2. perm is a cell array of arrays of permutation column indices as defined in 1. In that case, all permutations passed to the function are used as explained in 1, each one producing a Faust, the best one (that is the best approximation of M) is kept and returned by butterfly.
%> 3. perm is 'default_8', this is a particular case of 2. Eight default permutations are used. For the definition of those permutations please refer to [1]. %> 3. perm is 'default_8', this is a particular case of 2. Eight default permutations are used. For the definition of those permutations please refer to [1].
%> 4. By default this argument is empty, no permutation is used. %> 4. perm is 'bitrev': in that case the permutation is the bit-reversal permutation (cf. matfaust.tools.bitrev_perm).
%> 5. By default this argument is empty, no permutation is used.
%> %>
%> @retval F the Faust which is an approximate of M according to a butterfly support. %> @retval F the Faust which is an approximate of M according to a butterfly support.
%> %>
...@@ -21,14 +22,27 @@ ...@@ -21,14 +22,27 @@
%> >> import matfaust.wht %> >> import matfaust.wht
%> >> import matfaust.dft %> >> import matfaust.dft
%> >> import matfaust.fact.butterfly %> >> import matfaust.fact.butterfly
%> >> M = full(wht(32)); % it works with dft too! %> >> H = full(wht(32)); % it works with dft too!
%> >> F = butterfly(M, 'type', 'bbtree'); %> >> F = butterfly(H, 'type', 'bbtree');
%> >> err = norm(full(F)-M)/norm(M) %> >> err = norm(full(F)-H)/norm(M)
%> err = %> err =
%> %>
%> 1.4311e-15 %> 1.4311e-15
%> @endcode %> @endcode
%> %>
%> Use butterfly with simple permutations:
%> @code
%> >> M = rand(4, 4);
%> >> % without any permutation
%> >> F1 = butterfly(M, 'type', 'bbtree');
%> >> % which is equivalent to identity permutation
%> >> p = 1:4;
%> >> F2 = butterfly(M, 'type', 'bbtree', 'perm', p);
%> >> % then try another permutation
%> >> p2 = [2, 1, 4, 3];
%> >> F3 = butterfly(M, 'type', 'bbtree', 'perm', p2);
%> @endcode
%>
%> Use butterfly with a permutation factor defined by J: %> Use butterfly with a permutation factor defined by J:
%> @code %> @code
%> >> J = 32:-1:1; %> >> J = 32:-1:1;
...@@ -41,7 +55,6 @@ ...@@ -41,7 +55,6 @@
%> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64 %> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64 %> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64 %> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 5 (double) SPARSE, size 32x32, density 0.03125, nnz 32
%> @endcode %> @endcode
%> %>
%> Use butterfly with successive permutations J1 and J2 %> Use butterfly with successive permutations J1 and J2
...@@ -60,10 +73,22 @@ ...@@ -60,10 +73,22 @@
%> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64 %> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64 %> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64 %> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 5 (double) SPARSE, size 32x32, density 0.03125, nnz 32
%> @endcode %> @endcode
%> %>
%> %>
%> Or to to use the 8 default permutations (keeping the best approximation resulting Faust)
%> @code
%> >> F = butterfly(H, 'type', 'bbtree', 'perm', 'default_8')
%> F =
%>
%> Faust size 32x32, density 0.3125, nnz_sum 320, 5 factor(s):
%> - FACTOR 0 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 1 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> @endcode
%>
%> %>
%> <b>Reference: [1]</b> Leon Zheng, Elisa Riccietti, and Remi Gribonval, <a href="https://arxiv.org/pdf/2110.01230.pdf">Hierarchical Identifiability in Multi-layer Sparse Matrix Factorization</a> %> <b>Reference: [1]</b> Leon Zheng, Elisa Riccietti, and Remi Gribonval, <a href="https://arxiv.org/pdf/2110.01230.pdf">Hierarchical Identifiability in Multi-layer Sparse Matrix Factorization</a>
%========================================================================== %==========================================================================
...@@ -82,8 +107,8 @@ function F = butterfly(M, varargin) ...@@ -82,8 +107,8 @@ function F = butterfly(M, varargin)
type = varargin{i+1}; type = varargin{i+1};
end end
case 'perm' case 'perm'
if(nargin < i+1 || ~ is_array_of_indices(varargin{i+1}, M) && ~ is_cell_arrays_of_indices(varargin{i+1}, M) && ~ strcmp(varargin{i+1}, 'default_8')) if(nargin < i+1 || ~ is_array_of_indices(varargin{i+1}, M) && ~ is_cell_arrays_of_indices(varargin{i+1}, M) && ~ strcmp(varargin{i+1}, 'default_8') && ~ strcmp(varargin{i+1}, 'bitrev'))
error('keyword argument ''perm'' must be followed by ''default_8'', an array of permutation indices or a cell array of arrays of permutation indices') error('keyword argument ''perm'' must be followed by ''default_8'', ''bitrev'', an array of permutation indices or a cell array of arrays of permutation indices')
else else
perm = varargin{i+1}; perm = varargin{i+1};
end end
...@@ -100,6 +125,12 @@ function F = butterfly(M, varargin) ...@@ -100,6 +125,12 @@ function F = butterfly(M, varargin)
end end
F = matfaust.fact.butterfly(M, 'type', type, 'perm', permutations); F = matfaust.fact.butterfly(M, 'type', type, 'perm', permutations);
return; return;
elseif strcmp(perm, 'bitrev')
P = bitrev_perm(size(M, 2));
[perm, ~, ~] = find(P.'); % cf. comments above
perm = perm.';
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm);
return;
elseif iscell(perm) % perm is a cell of arrays, each one defining a permutation to test elseif iscell(perm) % perm is a cell of arrays, each one defining a permutation to test
% evaluate butterfly factorisation using the permutations and % evaluate butterfly factorisation using the permutations and
% keep the best Faust % keep the best Faust
...@@ -109,7 +140,7 @@ function F = butterfly(M, varargin) ...@@ -109,7 +140,7 @@ function F = butterfly(M, varargin)
% perm{i} % perm{i}
m = numel(perm{i}); m = numel(perm{i});
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i}); F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i});
err = norm(F-M, 'fro')/nM; err = norm(full(F)-M, 'fro')/nM;
if err < min_err if err < min_err
min_err = err; min_err = err;
best_F = F; best_F = F;
......
...@@ -32,6 +32,7 @@ from scipy.sparse import csr_matrix, csc_matrix, eye as seye, kron as skron ...@@ -32,6 +32,7 @@ from scipy.sparse import csr_matrix, csc_matrix, eye as seye, kron as skron
import pyfaust import pyfaust
import pyfaust.factparams import pyfaust.factparams
from pyfaust import Faust from pyfaust import Faust
from pyfaust.tools import bitrev_perm
import _FaustCorePy import _FaustCorePy
import warnings import warnings
...@@ -1337,10 +1338,13 @@ def butterfly(M, type="bbtree", perm=None): ...@@ -1337,10 +1338,13 @@ def butterfly(M, type="bbtree", perm=None):
factorization the most left factor (resp. the most right factor) is split in two. factorization the most left factor (resp. the most right factor) is split in two.
If 'bbtree' is used then the matrix is factorized according to a Balanced If 'bbtree' is used then the matrix is factorized according to a Balanced
Binary Tree (which is faster as it allows parallelization). Binary Tree (which is faster as it allows parallelization).
perm: four kind of values are possible for this argument (Note that this argument is made only for the bbtree type of perm: five kinds of values are possible for this argument.
factorization).
1. perm is a list of column indices of the permutation matrix P which is such that 1. perm is a list of column indices of the permutation matrix P which is such that
the returned Faust is F = G@P.T where G is the Faust butterfly approximation of M@P. the returned Faust is F = G@P where G is the Faust butterfly
approximation of M@P.T.
If the list of indices is not a valid permutation the behaviour
is undefined (however an invalid size or an out of bound index raise
an exception).
2. perm is a list of lists of permutation column indices as defined 2. perm is a list of lists of permutation column indices as defined
in 1. In that case, all permutations passed to the function are in 1. In that case, all permutations passed to the function are
used as explained in 1, each one producing a Faust, the best one used as explained in 1, each one producing a Faust, the best one
...@@ -1348,7 +1352,9 @@ def butterfly(M, type="bbtree", perm=None): ...@@ -1348,7 +1352,9 @@ def butterfly(M, type="bbtree", perm=None):
3. perm is 'default_8', this is a particular case of 2. Eight 3. perm is 'default_8', this is a particular case of 2. Eight
default permutations are used. For the definition of those default permutations are used. For the definition of those
permutations please refer to [1]. permutations please refer to [1].
4. By default this argument is None, no permutation is used. 4. perm is 'bitrev': in that case the permutation is the
bit-reversal permutation (cf. pyfaust.tools.bitrev_perm).
5. By default this argument is None, no permutation is used.
Returns: Returns:
The Faust which is an approximattion of M according to a butterfly support. The Faust which is an approximattion of M according to a butterfly support.
...@@ -1356,13 +1362,28 @@ def butterfly(M, type="bbtree", perm=None): ...@@ -1356,13 +1362,28 @@ def butterfly(M, type="bbtree", perm=None):
Example: Example:
>>> import numpy as np >>> import numpy as np
>>> from random import randint >>> from random import randint
>>> from pyfaust import Faust, wht, dft
>>> from pyfaust.fact import butterfly >>> from pyfaust.fact import butterfly
>>> from pyfaust import Faust, wht, dft
>>> H = wht(8).toarray() # it works with dft too! >>> H = wht(8).toarray() # it works with dft too!
>>> F = butterfly(H, type='bbtree') >>> F = butterfly(H, type='bbtree')
>>> (F-H).norm()/Faust(H).norm() >>> (F-H).norm()/Faust(H).norm()
1.0272844187006565e-15 1.0272844187006565e-15
# use butterfly with a permutation factor defined by J
Use simple permutations:
>>> import numpy as np
>>> from random import randint
>>> from pyfaust.fact import butterfly
>>> M = np.random.rand(4, 4)
>>> # without any permutation
>>> F1 = butterfly(M, type='bbtree')
>>> # which is equivalent to identity permutation
>>> p = np.arange(0, 4)
>>> F2 = butterfly(M, type='bbtree', perm=p)
>>> # then try another permutation
>>> p2 = [1, 0, 3, 2]
>>> F3 = butterfly(M, type='bbtree', perm=p2)
Use butterfly with a permutation factor defined by J:
>>> J = np.arange(7, -1, -1) >>> J = np.arange(7, -1, -1)
>>> F = butterfly(H, type='bbtree', perm=J) >>> F = butterfly(H, type='bbtree', perm=J)
# use butterfly with successive permutations J1 and J2 # use butterfly with successive permutations J1 and J2
...@@ -1381,74 +1402,77 @@ def butterfly(M, type="bbtree", perm=None): ...@@ -1381,74 +1402,77 @@ def butterfly(M, type="bbtree", perm=None):
International Conference on Acoustics, Speech and Signal Processing, International Conference on Acoustics, Speech and Signal Processing,
May 2022, Singapore, Singapore. (<a href="https://hal.inria.fr/hal-03438881">hal-03438881</a>) May 2022, Singapore, Singapore. (<a href="https://hal.inria.fr/hal-03438881">hal-03438881</a>)
""" """
from pyfaust.tools import bitrev_perm
is_real = np.empty((1,)) is_real = np.empty((1,))
if perm is not None and type != 'bbtree':
raise ValueError('perm argument is made only for type bbtree')
M = _check_fact_mat('butterfly()', M, is_real) M = _check_fact_mat('butterfly()', M, is_real)
if isinstance(perm, str) and perm == 'default_8': if isinstance(perm, str):
# the three modified functions below were originally extracted from the 3 clause-BSD code hosted here: https://github.com/leonzheng2/butterfly if perm == 'bitrev':
# please look the header license here https://github.com/leonzheng2/butterfly/blob/main/src/utils.py P = bitrev_perm(M.shape[1])
def perm_type(i, type): return butterfly(M, type, perm=P.indices)
""" elif perm == 'default_8':
Type 0 is c in paper. Type 1 is b in paper. Type 2 is a in paper. # the three modified functions below were originally extracted from the 3 clause-BSD code hosted here: https://github.com/leonzheng2/butterfly
:param i: # please look the header license here https://github.com/leonzheng2/butterfly/blob/main/src/utils.py
:param type: def perm_type(i, type):
:return: """
""" Type 0 is c in paper. Type 1 is b in paper. Type 2 is a in paper.
size = 2 ** i :param i:
if type == 0: :param type:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2))) :return:
col_inds = np.hstack((np.arange(size//2), size - 1 - np.arange(size//2))) """
elif type == 1: size = 2 ** i
row_inds = np.hstack((size // 2 - 1 - np.arange(size//2), size // 2 + np.arange(size//2))) if type == 0:
col_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2))) row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
else: col_inds = np.hstack((np.arange(size//2), size - 1 - np.arange(size//2)))
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2))) elif type == 1:
col_inds = np.hstack((np.arange(size//2) * 2, np.arange(size//2) * 2 + 1)) row_inds = np.hstack((size // 2 - 1 - np.arange(size//2), size // 2 + np.arange(size//2)))
result = csr_matrix((np.ones(row_inds.size), (row_inds, col_inds))) col_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
return result else:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
def shared_logits_permutation(num_factors, choices): col_inds = np.hstack((np.arange(size//2) * 2, np.arange(size//2) * 2 + 1))
""" result = csr_matrix((np.ones(row_inds.size), (row_inds, col_inds)))
:param num_factors: return result
:param choices: array of three bool
:return: def shared_logits_permutation(num_factors, choices):
""" """
permutations = [] :param num_factors:
for i in range(2, num_factors + 1): :param choices: array of three bool
block = seye(2 ** i) :return:
if choices[0]: """
block = block @ perm_type(i, 0) permutations = []
if choices[1]: for i in range(2, num_factors + 1):
block = block @ perm_type(i, 1) block = seye(2 ** i)
if choices[2]: if choices[0]:
block = block @ perm_type(i, 2) block = block @ perm_type(i, 0)
perm = skron(seye(2 ** (num_factors - i)), block) if choices[1]:
permutations.append(perm) block = block @ perm_type(i, 1)
return permutations if choices[2]:
block = block @ perm_type(i, 2)
def get_permutation_matrix(num_factors, perm_name): perm = skron(seye(2 ** (num_factors - i)), block)
""" permutations.append(perm)
:param num_factors: return permutations
:param perm_name: str, 000, 001, ..., 111
:return: def get_permutation_matrix(num_factors, perm_name):
""" """
if perm_name.isnumeric(): :param num_factors:
choices = [int(char) for char in perm_name] :param perm_name: str, 000, 001, ..., 111
p_list = shared_logits_permutation(num_factors, choices) :return:
p = csr_matrix(Faust(p_list).toarray()) # TODO: keep csr along the whole product """
else: if perm_name.isnumeric():
raise TypeError("perm_name must be numeric") choices = [int(char) for char in perm_name]
return p p_list = shared_logits_permutation(num_factors, choices)
p = csr_matrix(Faust(p_list).toarray()) # TODO: keep csr along the whole product
# print(list(get_permutation_matrix(int(np.log2(M.shape[0])), else:
# perm_name).indices+1 \ raise TypeError("perm_name must be numeric")
# for perm_name in ["000", "001", "010", "011", "100", return p
# "101", "110", "111"]))
permutations = [get_permutation_matrix(int(np.log2(M.shape[0])), # print(list(get_permutation_matrix(int(np.log2(M.shape[0])),
perm_name).indices \ # perm_name).indices+1 \
for perm_name in ["000", "001", "010", "011", "100", "101", "110", "111"]] # for perm_name in ["000", "001", "010", "011", "100",
return butterfly(M, type, permutations) # "101", "110", "111"]))
permutations = [get_permutation_matrix(int(np.log2(M.shape[0])),
perm_name).indices \
for perm_name in ["000", "001", "010", "011", "100", "101", "110", "111"]]
return butterfly(M, type, permutations)
elif isinstance(perm, (list, tuple)) and isinstance(perm[0], (list, tuple, elif isinstance(perm, (list, tuple)) and isinstance(perm[0], (list, tuple,
np.ndarray)): np.ndarray)):
# loop on each perm and keep the best approximation # loop on each perm and keep the best approximation
...@@ -1460,7 +1484,7 @@ def butterfly(M, type="bbtree", perm=None): ...@@ -1460,7 +1484,7 @@ def butterfly(M, type="bbtree", perm=None):
P = csr_matrix((np.ones(row_inds.size), (row_inds, p))) P = csr_matrix((np.ones(row_inds.size), (row_inds, p)))
F = butterfly(M, type, p) F = butterfly(M, type, p)
# compute error # compute error
error = np.linalg.norm(F-M)/Faust(M).norm() error = np.linalg.norm(F.toarray()-M)/Faust(M).norm()
# print(error) # print(error)
if error < best_err: if error < best_err:
best_err = error best_err = error
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment