Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 70a87997 authored by hhakim's avatar hhakim
Browse files

Update matfaust.Faust.factors/left/right and mex code to handle retrieving factors as Faust.

Avoid round trip between matlab and cpp with costful copies and not desired conversions from MatButterfly/MatPerm to MatSparse.
parent f3f6826f
No related branches found
No related tags found
No related merge requests found
......@@ -1720,9 +1720,9 @@ classdef Faust
%=====================================================================
%> @brief Returns the i-th factor or a new Faust composed of F factors whose indices are listed in indices.
%>
%> @note Factors are copied in memory.
%>
%>
%> @note Factors are not copied in memory if subset greater than one is asked and doesn't need conversion.
%>
%> @b Usage
%>
%> &nbsp;&nbsp;&nbsp; @b factor = factors(F, i) returns the i-th factor of F.<br/>
......@@ -1744,27 +1744,24 @@ classdef Faust
%> @endcode
%> <p>@b See @b also Faust.numfactors
%=====================================================================
function factors = factors(F, varargin)
factors = cell(1, size(varargin{1},2));
for j=1:length(factors)
i = varargin{1};
if(j < length(factors) && i(j+1) - i(j) ~= 1)
error('Indices must be contiguous.')
end
i = i(j);
if (~isa(i,'double'))
error('factors second argument (indice) must either be real positive integers or logicals.');
end
if (floor(i) ~= i)
error('factors second argument (indice) must either be real positive integers or logicals.');
function factors = factors(F, ids)
ids_err = 'factors indices must either be real positive integers or logicals.';
if ~ strcmp(class(ids), 'double')
error(ids_err);
end
for j=1:length(ids)
i = ids(j);
if floor(i) ~= i
error(ids_err);
end
factors{j} = call_mex(F, 'factors', i);
end
if(length(factors) > 1)
factors = matfaust.Faust(factors, 'dev', F.dev);
nargs = numel(ids);
if nargs > 1
factors = matfaust.Faust(F, call_mex(F, 'factors', uint64(ids-1)));
elseif nargs == 1
factors = call_mex(F, 'factors', uint64(ids-1));
else
factors = factors{j};
error('Empty range of factors')
end
end
......
......@@ -10,32 +10,38 @@ template <typename SCALAR, FDevice DEV>
void faust_factors(const mxArray **prhs, const int nrhs, mxArray **plhs, const int nlhs)
{
Faust::TransformHelper<SCALAR,DEV>* core_ptr = convertMat2Ptr<Faust::TransformHelper<SCALAR,DEV> >(prhs[1]);
if (nlhs > 1 || nrhs != 3)
{
mexErrMsgTxt("factors : incorrect number of arguments.");
}
int id = (int) (mxGetScalar(prhs[2])-1);
auto type = core_ptr->get_fact_type(id);
if(core_ptr->is_fact_sparse(id))
plhs[0] = transformFact2SparseMxArray(id,core_ptr);
//
// plhs[0] = FaustSpMat2mxArray(*dynamic_cast<const Faust::MatSparse<SCALAR,Cpu>*>(core_ptr->get_gen_fact(id)));
else if(core_ptr->is_fact_dense(id))
plhs[0] = transformFact2FullMxArray(id,core_ptr);
// plhs[0] = FaustMat2mxArray(*dynamic_cast<const Faust::MatDense<SCALAR,Cpu>*>(core_ptr->get_gen_fact(id)));
else if(core_ptr->is_fact_bsr(id))
{
plhs[0] = bsr_mat_to_sp_mat(id, core_ptr);
auto ids = static_cast<unsigned long int*>(mxGetData(prhs[2]));
auto n_ids = mxGetNumberOfElements(prhs[2]);
if(n_ids == 1)
{ // asked a single factor
int id = ids[0];
auto type = core_ptr->get_fact_type(id);
if(core_ptr->is_fact_sparse(id))
plhs[0] = transformFact2SparseMxArray(id, core_ptr);
else if(core_ptr->is_fact_dense(id))
plhs[0] = transformFact2FullMxArray(id, core_ptr);
else if(core_ptr->is_fact_bsr(id))
plhs[0] = bsr_mat_to_sp_mat(id, core_ptr);
else if(type == 4)
// MatButterfly
plhs[0] = butterfly_mat_to_sp_mat(id, core_ptr);
else if(type == 5)
// MatPerm
plhs[0] = perm_mat_to_sp_mat(id, core_ptr);
else
mexErrMsgTxt("Unhandled type of matrix");
}
else if(type == 4)
// MatButterfly
plhs[0] = butterfly_mat_to_sp_mat(id, core_ptr);
else if(type == 5)
// MatPerm
plhs[0] = perm_mat_to_sp_mat(id, core_ptr);
else
mexErrMsgTxt("Unhandled type of matrix");
{ // asked a group of factors // return a Faust
Faust::TransformHelper<SCALAR, DEV>* th = core_ptr->factors(ids, n_ids);
plhs[0] = convertPtr2Mat<Faust::TransformHelper<SCALAR,DEV> >(th);
}
}
......
......@@ -39,7 +39,7 @@ template<class base> inline mxArray *convertPtr2Mat(base *ptr)
template<class base> inline class_handle<base> *convertMat2HandlePtr(const mxArray *in)
{
if (mxGetNumberOfElements(in) != 1 || mxGetClassID(in) != mxUINT64_CLASS || mxIsComplex(in))
mexErrMsgTxt("Input must be a real uint64 scalar.");
mexErrMsgTxt("Input must be a uint64 scalar.");
class_handle<base> *ptr = reinterpret_cast<class_handle<base> *>(*((uint64_t *)mxGetData(in)));
if (!ptr->isValid())
mexErrMsgTxt("Handle not valid.");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment