Mentions légales du service

Skip to content
Snippets Groups Projects
Commit fac21fa7 authored by hhakim's avatar hhakim
Browse files

Update pyfaust/matfaust.fact.butterfly: add 'bitrev' option for perm argument and complete the doc.

parent 85c82b31
No related branches found
No related tags found
No related merge requests found
Pipeline #834098 skipped
......@@ -7,12 +7,13 @@
%> Binary Tree (which is faster as it allows parallelization).
%>
%>
%> @param 'perm', value four kind of values are possible for this argument (Note that this argument is made only for the bbtree type of factorization).
%> @param 'perm', value five kinds of values are possible for this argument.
%>
%> 1. perm is an array of column indices of the permutation matrix P which is such that the returned Faust is F = G * P.' where G is the Faust butterfly approximation of M*P.
%> 1. perm is an array of column indices of the permutation matrix P which is such that the returned Faust is F = G * P where G is the Faust butterfly approximation of M*P.'. If the array of indices is not a valid permutation the behaviour is undefined (however an invalid size or an out of bound index raise an exception).
%> 2. perm is a cell array of arrays of permutation column indices as defined in 1. In that case, all permutations passed to the function are used as explained in 1, each one producing a Faust, the best one (that is the best approximation of M) is kept and returned by butterfly.
%> 3. perm is 'default_8', this is a particular case of 2. Eight default permutations are used. For the definition of those permutations please refer to [1].
%> 4. By default this argument is empty, no permutation is used.
%> 4. perm is 'bitrev': in that case the permutation is the bit-reversal permutation (cf. matfaust.tools.bitrev_perm).
%> 5. By default this argument is empty, no permutation is used.
%>
%> @retval F the Faust which is an approximate of M according to a butterfly support.
%>
......@@ -21,14 +22,27 @@
%> >> import matfaust.wht
%> >> import matfaust.dft
%> >> import matfaust.fact.butterfly
%> >> M = full(wht(32)); % it works with dft too!
%> >> F = butterfly(M, 'type', 'bbtree');
%> >> err = norm(full(F)-M)/norm(M)
%> >> H = full(wht(32)); % it works with dft too!
%> >> F = butterfly(H, 'type', 'bbtree');
%> >> err = norm(full(F)-H)/norm(M)
%> err =
%>
%> 1.4311e-15
%> @endcode
%>
%> Use butterfly with simple permutations:
%> @code
%> >> M = rand(4, 4);
%> >> % without any permutation
%> >> F1 = butterfly(M, 'type', 'bbtree');
%> >> % which is equivalent to identity permutation
%> >> p = 1:4;
%> >> F2 = butterfly(M, 'type', 'bbtree', 'perm', p);
%> >> % then try another permutation
%> >> p2 = [2, 1, 4, 3];
%> >> F3 = butterfly(M, 'type', 'bbtree', 'perm', p2);
%> @endcode
%>
%> Use butterfly with a permutation factor defined by J:
%> @code
%> >> J = 32:-1:1;
......@@ -41,7 +55,6 @@
%> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 5 (double) SPARSE, size 32x32, density 0.03125, nnz 32
%> @endcode
%>
%> Use butterfly with successive permutations J1 and J2
......@@ -60,10 +73,22 @@
%> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 5 (double) SPARSE, size 32x32, density 0.03125, nnz 32
%> @endcode
%>
%>
%> Or to to use the 8 default permutations (keeping the best approximation resulting Faust)
%> @code
%> >> F = butterfly(H, 'type', 'bbtree', 'perm', 'default_8')
%> F =
%>
%> Faust size 32x32, density 0.3125, nnz_sum 320, 5 factor(s):
%> - FACTOR 0 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 1 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 2 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> @endcode
%>
%>
%> <b>Reference: [1]</b> Leon Zheng, Elisa Riccietti, and Remi Gribonval, <a href="https://arxiv.org/pdf/2110.01230.pdf">Hierarchical Identifiability in Multi-layer Sparse Matrix Factorization</a>
%==========================================================================
......@@ -82,8 +107,8 @@ function F = butterfly(M, varargin)
type = varargin{i+1};
end
case 'perm'
if(nargin < i+1 || ~ is_array_of_indices(varargin{i+1}, M) && ~ is_cell_arrays_of_indices(varargin{i+1}, M) && ~ strcmp(varargin{i+1}, 'default_8'))
error('keyword argument ''perm'' must be followed by ''default_8'', an array of permutation indices or a cell array of arrays of permutation indices')
if(nargin < i+1 || ~ is_array_of_indices(varargin{i+1}, M) && ~ is_cell_arrays_of_indices(varargin{i+1}, M) && ~ strcmp(varargin{i+1}, 'default_8') && ~ strcmp(varargin{i+1}, 'bitrev'))
error('keyword argument ''perm'' must be followed by ''default_8'', ''bitrev'', an array of permutation indices or a cell array of arrays of permutation indices')
else
perm = varargin{i+1};
end
......@@ -100,6 +125,12 @@ function F = butterfly(M, varargin)
end
F = matfaust.fact.butterfly(M, 'type', type, 'perm', permutations);
return;
elseif strcmp(perm, 'bitrev')
P = bitrev_perm(size(M, 2));
[perm, ~, ~] = find(P.'); % cf. comments above
perm = perm.';
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm);
return;
elseif iscell(perm) % perm is a cell of arrays, each one defining a permutation to test
% evaluate butterfly factorisation using the permutations and
% keep the best Faust
......@@ -109,7 +140,7 @@ function F = butterfly(M, varargin)
% perm{i}
m = numel(perm{i});
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i});
err = norm(F-M, 'fro')/nM;
err = norm(full(F)-M, 'fro')/nM;
if err < min_err
min_err = err;
best_F = F;
......
......@@ -32,6 +32,7 @@ from scipy.sparse import csr_matrix, csc_matrix, eye as seye, kron as skron
import pyfaust
import pyfaust.factparams
from pyfaust import Faust
from pyfaust.tools import bitrev_perm
import _FaustCorePy
import warnings
......@@ -1337,10 +1338,13 @@ def butterfly(M, type="bbtree", perm=None):
factorization the most left factor (resp. the most right factor) is split in two.
If 'bbtree' is used then the matrix is factorized according to a Balanced
Binary Tree (which is faster as it allows parallelization).
perm: four kind of values are possible for this argument (Note that this argument is made only for the bbtree type of
factorization).
perm: five kinds of values are possible for this argument.
1. perm is a list of column indices of the permutation matrix P which is such that
the returned Faust is F = G@P.T where G is the Faust butterfly approximation of M@P.
the returned Faust is F = G@P where G is the Faust butterfly
approximation of M@P.T.
If the list of indices is not a valid permutation the behaviour
is undefined (however an invalid size or an out of bound index raise
an exception).
2. perm is a list of lists of permutation column indices as defined
in 1. In that case, all permutations passed to the function are
used as explained in 1, each one producing a Faust, the best one
......@@ -1348,7 +1352,9 @@ def butterfly(M, type="bbtree", perm=None):
3. perm is 'default_8', this is a particular case of 2. Eight
default permutations are used. For the definition of those
permutations please refer to [1].
4. By default this argument is None, no permutation is used.
4. perm is 'bitrev': in that case the permutation is the
bit-reversal permutation (cf. pyfaust.tools.bitrev_perm).
5. By default this argument is None, no permutation is used.
Returns:
The Faust which is an approximattion of M according to a butterfly support.
......@@ -1356,13 +1362,28 @@ def butterfly(M, type="bbtree", perm=None):
Example:
>>> import numpy as np
>>> from random import randint
>>> from pyfaust import Faust, wht, dft
>>> from pyfaust.fact import butterfly
>>> from pyfaust import Faust, wht, dft
>>> H = wht(8).toarray() # it works with dft too!
>>> F = butterfly(H, type='bbtree')
>>> (F-H).norm()/Faust(H).norm()
1.0272844187006565e-15
# use butterfly with a permutation factor defined by J
Use simple permutations:
>>> import numpy as np
>>> from random import randint
>>> from pyfaust.fact import butterfly
>>> M = np.random.rand(4, 4)
>>> # without any permutation
>>> F1 = butterfly(M, type='bbtree')
>>> # which is equivalent to identity permutation
>>> p = np.arange(0, 4)
>>> F2 = butterfly(M, type='bbtree', perm=p)
>>> # then try another permutation
>>> p2 = [1, 0, 3, 2]
>>> F3 = butterfly(M, type='bbtree', perm=p2)
Use butterfly with a permutation factor defined by J:
>>> J = np.arange(7, -1, -1)
>>> F = butterfly(H, type='bbtree', perm=J)
# use butterfly with successive permutations J1 and J2
......@@ -1381,74 +1402,77 @@ def butterfly(M, type="bbtree", perm=None):
International Conference on Acoustics, Speech and Signal Processing,
May 2022, Singapore, Singapore. (<a href="https://hal.inria.fr/hal-03438881">hal-03438881</a>)
"""
from pyfaust.tools import bitrev_perm
is_real = np.empty((1,))
if perm is not None and type != 'bbtree':
raise ValueError('perm argument is made only for type bbtree')
M = _check_fact_mat('butterfly()', M, is_real)
if isinstance(perm, str) and perm == 'default_8':
# the three modified functions below were originally extracted from the 3 clause-BSD code hosted here: https://github.com/leonzheng2/butterfly
# please look the header license here https://github.com/leonzheng2/butterfly/blob/main/src/utils.py
def perm_type(i, type):
"""
Type 0 is c in paper. Type 1 is b in paper. Type 2 is a in paper.
:param i:
:param type:
:return:
"""
size = 2 ** i
if type == 0:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2), size - 1 - np.arange(size//2)))
elif type == 1:
row_inds = np.hstack((size // 2 - 1 - np.arange(size//2), size // 2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
else:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2) * 2, np.arange(size//2) * 2 + 1))
result = csr_matrix((np.ones(row_inds.size), (row_inds, col_inds)))
return result
def shared_logits_permutation(num_factors, choices):
"""
:param num_factors:
:param choices: array of three bool
:return:
"""
permutations = []
for i in range(2, num_factors + 1):
block = seye(2 ** i)
if choices[0]:
block = block @ perm_type(i, 0)
if choices[1]:
block = block @ perm_type(i, 1)
if choices[2]:
block = block @ perm_type(i, 2)
perm = skron(seye(2 ** (num_factors - i)), block)
permutations.append(perm)
return permutations
def get_permutation_matrix(num_factors, perm_name):
"""
:param num_factors:
:param perm_name: str, 000, 001, ..., 111
:return:
"""
if perm_name.isnumeric():
choices = [int(char) for char in perm_name]
p_list = shared_logits_permutation(num_factors, choices)
p = csr_matrix(Faust(p_list).toarray()) # TODO: keep csr along the whole product
else:
raise TypeError("perm_name must be numeric")
return p
# print(list(get_permutation_matrix(int(np.log2(M.shape[0])),
# perm_name).indices+1 \
# for perm_name in ["000", "001", "010", "011", "100",
# "101", "110", "111"]))
permutations = [get_permutation_matrix(int(np.log2(M.shape[0])),
perm_name).indices \
for perm_name in ["000", "001", "010", "011", "100", "101", "110", "111"]]
return butterfly(M, type, permutations)
if isinstance(perm, str):
if perm == 'bitrev':
P = bitrev_perm(M.shape[1])
return butterfly(M, type, perm=P.indices)
elif perm == 'default_8':
# the three modified functions below were originally extracted from the 3 clause-BSD code hosted here: https://github.com/leonzheng2/butterfly
# please look the header license here https://github.com/leonzheng2/butterfly/blob/main/src/utils.py
def perm_type(i, type):
"""
Type 0 is c in paper. Type 1 is b in paper. Type 2 is a in paper.
:param i:
:param type:
:return:
"""
size = 2 ** i
if type == 0:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2), size - 1 - np.arange(size//2)))
elif type == 1:
row_inds = np.hstack((size // 2 - 1 - np.arange(size//2), size // 2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
else:
row_inds = np.hstack((np.arange(size//2), size//2 + np.arange(size//2)))
col_inds = np.hstack((np.arange(size//2) * 2, np.arange(size//2) * 2 + 1))
result = csr_matrix((np.ones(row_inds.size), (row_inds, col_inds)))
return result
def shared_logits_permutation(num_factors, choices):
"""
:param num_factors:
:param choices: array of three bool
:return:
"""
permutations = []
for i in range(2, num_factors + 1):
block = seye(2 ** i)
if choices[0]:
block = block @ perm_type(i, 0)
if choices[1]:
block = block @ perm_type(i, 1)
if choices[2]:
block = block @ perm_type(i, 2)
perm = skron(seye(2 ** (num_factors - i)), block)
permutations.append(perm)
return permutations
def get_permutation_matrix(num_factors, perm_name):
"""
:param num_factors:
:param perm_name: str, 000, 001, ..., 111
:return:
"""
if perm_name.isnumeric():
choices = [int(char) for char in perm_name]
p_list = shared_logits_permutation(num_factors, choices)
p = csr_matrix(Faust(p_list).toarray()) # TODO: keep csr along the whole product
else:
raise TypeError("perm_name must be numeric")
return p
# print(list(get_permutation_matrix(int(np.log2(M.shape[0])),
# perm_name).indices+1 \
# for perm_name in ["000", "001", "010", "011", "100",
# "101", "110", "111"]))
permutations = [get_permutation_matrix(int(np.log2(M.shape[0])),
perm_name).indices \
for perm_name in ["000", "001", "010", "011", "100", "101", "110", "111"]]
return butterfly(M, type, permutations)
elif isinstance(perm, (list, tuple)) and isinstance(perm[0], (list, tuple,
np.ndarray)):
# loop on each perm and keep the best approximation
......@@ -1460,7 +1484,7 @@ def butterfly(M, type="bbtree", perm=None):
P = csr_matrix((np.ones(row_inds.size), (row_inds, p)))
F = butterfly(M, type, p)
# compute error
error = np.linalg.norm(F-M)/Faust(M).norm()
error = np.linalg.norm(F.toarray()-M)/Faust(M).norm()
# print(error)
if error < best_err:
best_err = error
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment