Mentions légales du service

Skip to content
Snippets Groups Projects
Commit d534cc08 authored by hhakim's avatar hhakim
Browse files

Implement diag_opt option for matfaust.fact.butterfly.

 #275
parent 5bfb26f9
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@
%> If 'bbtree' is used then the matrix is factorized according to a Balanced
%> Binary Tree (which is faster as it allows parallelization).
%> @param 'perm', value five kinds of values are possible for this argument.
%> @param 'diag_opt', bool: if true then the returned Faust is optimized using matfaust.opt_butterfly_faust.
%>
%> 1. perm is an array of column indices of the permutation matrix P which is such that the returned Faust is F = B * P where B is the Faust butterfly approximation of M*P.'. If the array of indices is not a valid permutation the behaviour is undefined (however an invalid size or an out of bound index raise an exception).
%> 2. perm is a cell array of arrays of permutation column indices as defined in 1. In that case, all permutations passed to the function are used as explained in 1, each one producing a Faust, the best one (that is the best approximation of M in the Frobenius norm) is kept and returned by butterfly.
......@@ -161,6 +162,7 @@ function F = butterfly(M, varargin)
nargin = length(varargin);
type = 'right';
perm = [];
diag_opt = false;
if(nargin > 0)
for i=1:2:nargin
switch(varargin{i})
......@@ -176,6 +178,12 @@ function F = butterfly(M, varargin)
else
perm = varargin{i+1};
end
case 'diag_opt'
if(nargin < i+1 || ~ islogical(varargin{i+1}))
error('diag_opt keyword argument is not followed by a logical')
else
diag_opt = varargin{i+1};
end
end
end
end
......@@ -187,13 +195,13 @@ function F = butterfly(M, varargin)
[permutations{i}, ~, ~] = find(P);
permutations{i} = permutations{i}.'; % just for readibility in case of printing
end
F = matfaust.fact.butterfly(M, 'type', type, 'perm', permutations);
F = matfaust.fact.butterfly(M, 'type', type, 'perm', permutations, 'diag_opt', diag_opt);
return;
elseif strcmp(perm, 'bitrev')
P = bitrev_perm(size(M, 2));
[perm, ~, ~] = find(P);
perm = perm.';
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm);
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm, 'diag_opt', diag_opt);
return;
elseif iscell(perm) % perm is a cell of arrays, each one defining a permutation to test
% evaluate butterfly factorisation using the permutations and
......@@ -203,7 +211,7 @@ function F = butterfly(M, varargin)
for i=1:length(perm)
% perm{i}
m = numel(perm{i});
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i});
F = matfaust.fact.butterfly(M, 'type', type, 'perm', perm{i}, 'diag_opt', diag_opt);
err = norm(full(F)-M, 'fro')/nM;
if err < min_err
min_err = err;
......@@ -221,13 +229,13 @@ function F = butterfly(M, varargin)
type = 2;
end
if(strcmp(class(M), 'single'))
core_obj = mexButterflyRealFloat(M, type, perm);
core_obj = mexButterflyRealFloat(M, type, perm, ~ diag_opt); % last arg is mul_perm
F = Faust(core_obj, isreal(M), 'cpu', 'float', true); % 4th arg is for copying
else
if(isreal(M))
core_obj = mexButterflyReal(M, type, perm);
core_obj = mexButterflyReal(M, type, perm, ~ diag_opt);
else
core_obj = mexButterflyCplx(M, type, perm);
core_obj = mexButterflyCplx(M, type, perm, ~ diag_opt);
end
F = Faust(core_obj, isreal(M), 'cpu', 'double', true); % 4th arg is for copying
end
......
......@@ -85,12 +85,16 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
ButterflyFactDir dir = RIGHT;
if(nrhs >= 2)
dir = static_cast<ButterflyFactDir>(static_cast<int>(mxGetScalar(prhs[1])));
bool mul_perm = true;
if(nrhs >= 3)
{
#ifndef MX_HAS_INTERLEAVED_COMPLEX
#error perm argument compiling isn't supported for Matlab versions prior to R2018a
#endif
auto perm_inds = prhs[2];
if(nrhs >= 4)
mul_perm = (bool) mxGetScalar(prhs[3]);
if(mxIsSparse(perm_inds))
mexErrMsgTxt("The matrix of indices must be dense");
else
......@@ -131,7 +135,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
}
}
F = butterfly_hierarchical(matrix, dir, perm_mat);
F = butterfly_hierarchical(matrix, dir, perm_mat, mul_perm);
plhs[0] = convertPtr2Mat<Faust::TransformHelper<SCALAR, Cpu>>(F);
if(perm_mat != nullptr)
delete perm_mat;
......
......@@ -77,7 +77,7 @@ classdef FaustCore < handle
if copy && ~ strcmp(this.dtype, 'float')
% don't use the copy for 'float' because matlab doesn't support
% single precision sparse matrices
onGPU = startsWith(this.dev, 'gpu')
onGPU = startsWith(this.dev, 'gpu');
nf = call_mex(this, 'numfactors');
facts = cell(1, nf);
for i=1:nf
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment