Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 27aa8306 authored by hhakim's avatar hhakim
Browse files

Correct butterfly permutation feature to obtain a F that approximate M instead of M*P

parent ef26d239
No related branches found
No related tags found
No related merge requests found
......@@ -681,6 +681,15 @@ namespace Faust
th = butterfly_hierarchical(A, support, dir);
for(int i = 0; i < (P == nullptr?support.size():support.size()-1);i++)
delete support[i];
if(P != nullptr)
{
//TODO: maybe a swapcols on th would be wiser/quicker
MatSparse<FPP, Cpu> Pt(*P);
Pt.transpose();
auto last_fac = dynamic_cast<const MatSparse<FPP, Cpu>*>(th->get_gen_fact(th->size()-1));
Pt.multiplyRight(*last_fac);
th->replace(new MatSparse<FPP, Cpu>(Pt), th->size()-1);
}
return th;
}
......
%==========================================================================
%> @brief Factorizes the matrix M according to a butterfly support.
%> @param 'type', str: the type of factorization 'right'ward, 'left'ward or 'bbtree'.
%> More precisely: if 'left' (resp. 'right') is used then at each stage of the
%> factorization the most left factor (resp. the most right factor) is split in two.
......@@ -10,7 +9,7 @@
%>
%> @param 'perm', value four kind of values are possible for this argument (Note that this argument is made only for the bbtree type of factorization).
%>
%> 1. perm is an array of column indices of the permutation matrix P which is such that the returned Faust F is the approximation of M*P and F*P.' the approximation of M.
%> 1. perm is an array of column indices of the permutation matrix P which is such that the returned Faust is F = G * P.' where G is the Faust butterfly approximation of M*P.
%> 2. perm is a cell array of arrays of permutation column indices as defined in 1. In that case, all permutations passed to the function are used as explained in 1, each one producing a Faust, the best one (that is the best approximation of M) is kept and returned by butterfly.
%> 3. perm is 'default_8', this is a particular case of 2. Eight default permutations are used. For the definition of those permutations please refer to [1].
%> 4. By default this argument is empty, no permutation is used.
......@@ -62,17 +61,10 @@
%> - FACTOR 3 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 4 (double) SPARSE, size 32x32, density 0.0625, nnz 64
%> - FACTOR 5 (double) SPARSE, size 32x32, density 0.03125, nnz 32
%> >> [Jbest, ~, ~] = find(factors(F, 6).');
%> >> all(all(Jbest.' == J1))
%>
%> ans =
%>
%> logical
%> 1
%> >> # here the J1 permutation is the best one
%> @endcode
%>
%>
%>
%> <b>Reference: [1]</b> Leon Zheng, Elisa Riccietti, and Remi Gribonval, <a href="https://arxiv.org/pdf/2110.01230.pdf">Hierarchical Identifiability in Multi-layer Sparse Matrix Factorization</a>
%==========================================================================
function F = butterfly(M, varargin)
......@@ -116,10 +108,8 @@ function F = butterfly(M, varargin)
for i=1:length(perm)
% perm{i}
m = numel(perm{i});
P = sparse(1:m, perm{i}, ones(1, m), m, m);
MP = M*P;
F = matfaust.fact.butterfly(MP, 'type', type, 'perm', perm{i});
err = norm(F*P'-M, 'fro')/nM;
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i});
err = norm(F-M, 'fro')/nM;
if err < min_err
min_err = err;
best_F = F;
......
......@@ -1340,8 +1340,7 @@ def butterfly(M, type="bbtree", perm=None):
perm: four kind of values are possible for this argument (Note that this argument is made only for the bbtree type of
factorization).
1. perm is a list of column indices of the permutation matrix P which is such that
the returned Faust F is the approximation of M@P and F@P.T the
approximation of M.
the returned Faust is F = G@P.T where G is the Faust butterfly approximation of M@P.
2. perm is a list of lists of permutation column indices as defined
in 1. In that case, all permutations passed to the function are
used as explained in 1, each one producing a Faust, the best one
......@@ -1373,9 +1372,8 @@ def butterfly(M, type="bbtree", perm=None):
>>> permutations = list(permutations(J))
>>> J2 = list(permutations[randint(0, len(permutations)-1)])
>>> F = butterfly(H, type='bbtree', perm=[J1, J2])
>>> np.allclose(F.factors(3).indices, J1)
True
>>> # here the best permutation is J1
>>> # or to to use the 8 default permutations (keeping the best approximation resulting Faust)
>>> F = butterfly(H, type='bbtree', perm='default_8')
Reference:
[1] Quoc-Tung Le, Léon Zheng, Elisa Riccietti, Rémi Gribonval. Fast
......@@ -1460,10 +1458,9 @@ def butterfly(M, type="bbtree", perm=None):
# print(p)
row_inds = np.arange(len(p))
P = csr_matrix((np.ones(row_inds.size), (row_inds, p)))
MP = M@P
F = butterfly(MP, type, p)
F = butterfly(M, type, p)
# compute error
error = np.linalg.norm(F@P.T-M)/Faust(M).norm()
error = np.linalg.norm(F-M)/Faust(M).norm()
# print(error)
if error < best_err:
best_err = error
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment