Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 892b8f02 authored by hhakim's avatar hhakim
Browse files

Correct permutation matrix to indices conversion in pyfaust and matfaust wrappers.

parent 182fb4f1
No related branches found
No related tags found
No related merge requests found
......@@ -140,14 +140,14 @@ function F = butterfly(M, varargin)
pchoices = {'000', '001', '010', '011', '100', '101', '110', '111'};
for i=1:8
P = get_permutation_matrix(floor(log2(size(M, 1))), pchoices{i});
[permutations{i}, ~, ~] = find(P.'); % don't get the column indices directly because it would always 1 to size(P, 1) (indeed P is in CSC format), rather get them in the proper order (row 0 to size(P, 1)) by getting the row indices of P transpose
[permutations{i}, ~, ~] = find(P);
permutations{i} = permutations{i}.'; % just for readibility in case of printing
end
F = matfaust.fact.butterfly(M, 'type', type, 'perm', permutations);
return;
elseif strcmp(perm, 'bitrev')
P = bitrev_perm(size(M, 2));
[perm, ~, ~] = find(P.'); % cf. comments above
[perm, ~, ~] = find(P);
perm = perm.';
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm);
return;
......
......@@ -117,15 +117,15 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
newMxGetData(ptr_data, perm_inds);
unsigned int *row_ids = new unsigned int[perm_size];
unsigned int *ucol_ids = new unsigned int[perm_size];
unsigned int *col_ids = new unsigned int[perm_size];
for(int i=0;i < perm_size;i++)
ucol_ids[i] = (unsigned int) ptr_data[i]-1; // matlab is one-base indexed
std::iota(row_ids, row_ids+perm_size, 0);
row_ids[i] = (unsigned int) ptr_data[i]-1; // matlab is one-base indexed
std::iota(col_ids, col_ids+perm_size, 0);
SCALAR *ones = new SCALAR[perm_size];
std::fill(ones, ones+perm_size, SCALAR(1));
perm_mat = new Faust::MatSparse<SCALAR, Cpu>(row_ids, ucol_ids, ones, perm_size, perm_size, perm_size);
perm_mat = new Faust::MatSparse<SCALAR, Cpu>(row_ids, col_ids, ones, perm_size, perm_size, perm_size);
delete[] row_ids;
delete[] ucol_ids;
delete[] col_ids;
delete[] ones;
}
}
......
......@@ -1423,12 +1423,14 @@ def butterfly(M, type="bbtree", perm=None):
2019, pp. 1517–1527, PMLR
"""
from pyfaust.tools import bitrev_perm
def perm2indices(P):
return P.T.nonzero()[1]
is_real = np.empty((1,))
M = _check_fact_mat('butterfly()', M, is_real)
if isinstance(perm, str):
if perm == 'bitrev':
P = bitrev_perm(M.shape[1])
return butterfly(M, type, perm=P.indices)
return butterfly(M, type, perm=perm2indices(P))
elif perm == 'default_8':
# the three modified functions below were originally extracted from the 3 clause-BSD code hosted here: https://github.com/leonzheng2/butterfly
# please look the header license here https://github.com/leonzheng2/butterfly/blob/main/src/utils.py
......@@ -1485,12 +1487,12 @@ def butterfly(M, type="bbtree", perm=None):
raise TypeError("perm_name must be numeric")
return p
# print(list(get_permutation_matrix(int(np.log2(M.shape[0])),
# perm_name).indices+1 \
# print(list(perm2indices(get_permutation_matrix(int(np.log2(M.shape[0])),
# perm_name))+1 \
# for perm_name in ["000", "001", "010", "011", "100",
# "101", "110", "111"]))
permutations = [get_permutation_matrix(int(np.log2(M.shape[0])),
perm_name).indices \
permutations = [perm2indices(get_permutation_matrix(int(np.log2(M.shape[0])),
perm_name)) \
for perm_name in ["000", "001", "010", "011", "100", "101", "110", "111"]]
return butterfly(M, type, permutations)
elif isinstance(perm, (list, tuple)) and isinstance(perm[0], (list, tuple,
......
......@@ -891,13 +891,12 @@ FaustCoreCpp<FPP>* butterfly_hierarchical(FPP* mat, unsigned int num_rows, unsig
if(perm != nullptr)
{
//Faust::MatSparse<FPP,Cpu>::MatSparse(const unsigned int* rowidx, const unsigned int* colidx, const FPP* values, const faust_unsigned_int dim1_, const faust_unsigned_int dim2_, faust_unsigned_int nnz)
unsigned int *row_ids = new unsigned int[num_cols];
unsigned int *ucol_ids = (unsigned int*) perm;
std::iota(row_ids, row_ids+num_cols, 0);
unsigned int *col_ids = new unsigned int[num_cols];
std::iota(col_ids, col_ids+num_cols, 0);
FPP *ones = new FPP[num_cols];
std::fill(ones, ones+num_cols, FPP(1));
perm_mat = new Faust::MatSparse<FPP, Cpu>((unsigned int*)row_ids, (unsigned int*)perm, ones, num_cols, num_cols, num_cols);
delete[] row_ids;
perm_mat = new Faust::MatSparse<FPP, Cpu>((unsigned int*) perm, (unsigned int*) col_ids, ones, num_cols, num_cols, num_cols);
delete[] col_ids;
delete[] ones;
}
FaustCoreCpp<FPP>* core = nullptr;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment