Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 3fe8d601 authored by hhakim's avatar hhakim
Browse files

Add the default_8 option (8 deft permutations) to matfaust.fact.butterfly.

parent d0f7bf05
No related branches found
No related tags found
No related merge requests found
...@@ -47,22 +47,33 @@ function F = butterfly(M, varargin) ...@@ -47,22 +47,33 @@ function F = butterfly(M, varargin)
end end
end end
end end
if iscell(perm) % perm is a cell of arrays, each one defining a permutation to test if strcmp(perm, 'default_8')
% evaluate butterfly factorisation using the permutations and permutations = cell(1, 8);
% keep the best Faust pchoices = {'000', '001', '010', '011', '100', '101', '110', '111'};
min_err = inf; for i=1:8
nM = norm(M); P = get_permutation_matrix(floor(log2(size(M, 1))), pchoices{i});
for i=1:length(perm) [permutations{i}, ~, ~] = find(P.'); % don't get the column indices directly because it would always 1 to size(P, 1) (indeed P is in CSC format), rather get them in the proper order (row 0 to size(P, 1)) by getting the row indices of P transpose
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i}); permutations{i} = permutations{i}.'; % just for readibility in case of printing
err = norm(full(F)-M)/nM; end
if err < min_err F = matfaust.fact.butterfly(M, 'type', type, 'perm', permutations);
min_err = err; return;
best_F = F; elseif iscell(perm) % perm is a cell of arrays, each one defining a permutation to test
end % evaluate butterfly factorisation using the permutations and
end % keep the best Faust
F = best_F; min_err = inf;
return; nM = norm(M, 'fro');
end for i=1:length(perm)
perm{i}
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i});
err = norm(full(F)-M, 'fro')/nM
if err < min_err
min_err = err;
best_F = F;
end
end
F = best_F;
return;
end
if(strcmp(type, 'right')) if(strcmp(type, 'right'))
type = 1; type = 1;
elseif(strcmp(type, 'left')) elseif(strcmp(type, 'left'))
...@@ -100,3 +111,50 @@ function b = is_cell_arrays_of_indices(value, M) ...@@ -100,3 +111,50 @@ function b = is_cell_arrays_of_indices(value, M)
end end
end end
end end
function perm = perm_type(i, type)
%% Type 0 is c in paper. Type 1 is b in paper. Type 2 is a in paper.
size = 2^i;
size_o2 = size / 2;
switch(type)
case 0
row_inds = [1:size_o2, size_o2+1:size];
col_inds = [1:size_o2, size:-1:size-size/2+1];
case 1
row_inds = [size_o2:-1:1, size_o2 + (1:size_o2)];
col_inds = [1:size_o2, size_o2 + (1:size_o2)];
case 2
row_inds = [1:size_o2, size_o2 + (1:size_o2)];
col_inds = [1:2:size-1, 2:2:size];
otherwise
error('perm_type received an invalid value for type argument')
end
m = numel(row_inds);
n = m;
nnz = m;
perm = sparse(row_inds, col_inds, ones(1, m), m, n, nnz);
end
function permutations = shared_logits_permutation(num_factors, choices)
% choices: array of three logical-s
permutations = {};
for i=2:num_factors
block = speye(2^i);
for j=1:3
if choices(j)
block = block * perm_type(i, j-1);
end
end
permutations = [permutations, {kron(speye(2 ^(num_factors - i)), block)}];
end
end
function p = get_permutation_matrix(num_factors, perm_name)
% perm_name: str 000, 001, ..., 111
choices = zeros(1, 3);
for i=1:length(perm_name)
choices(i) = str2double(perm_name(i));
end
permutations = shared_logits_permutation(num_factors, choices);
p = sparse(full(matfaust.Faust(permutations)));
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment