Mentions légales du service

Skip to content
Snippets Groups Projects
Commit f8f26bee authored by hhakim's avatar hhakim
Browse files

Enable matfaust.basis native GPU impl (argument dev).

parent c07ea42f
Branches
Tags
No related merge requests found
......@@ -6,6 +6,7 @@
%> @param K the degree of the last polynomial, i.e. the K+1 first polynomials are built.
%> @param basis_name 'chebyshev', and others yet to come.
%> @param 'T0', matrix (optional): a sparse matrix to replace the identity as a 0-degree polynomial of the basis.
%> @param 'dev', str (optional): the device to instantiate the returned Faust ('cpu' or 'gpu').
%>
%> @retval F the Faust of the basis composed of the K+1 orthogonal polynomials.
%>
......@@ -58,6 +59,7 @@ function F = basis(L, K, basis_name, varargin)
T0_is_set = false;
T0 = []; % no T0 by default
argc = length(varargin);
dev = 'cpu';
if(argc > 0)
for i=1:2:argc
if(argc > i)
......@@ -86,7 +88,7 @@ function F = basis(L, K, basis_name, varargin)
end
end
mex_args = {basis_name, L, K};
mex_args = {basis_name, L, K, startsWith(dev, 'gpu')};
if(T0_is_set)
if(size(T0,1) ~= size(L,2))
......@@ -98,7 +100,6 @@ function F = basis(L, K, basis_name, varargin)
mex_args = [mex_args {T0}];
end
if(strcmp(basis_name, 'chebyshev'))
if(is_real)
core_obj = mexPolyReal(mex_args{:});
......
......@@ -57,24 +57,39 @@ void chebyshev(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if(nrhs >= 2)
{
get_sp_mat(prhs[1], &L);
bool on_gpu = false;
if(nrhs >= 3)
K = (int) mxGetScalar(prhs[2]);
else
mexErrMsgTxt("chebyshev mex error: K argument is mandatory");
// optional T0 matrix
if(nrhs >= 4)
{
get_sp_mat(prhs[3], &opt_T0);
on_gpu = (bool) mxGetScalar(prhs[3]);
}
// optional T1 matrix
if(nrhs >= 5)
{
get_sp_mat(prhs[4], &opt_T0);
T0 = &opt_T0;
}
auto thp = Faust::basisChebyshev(&L, K, T0);
if(on_gpu)
{
#ifdef USE_GPU_MOD
Faust::enable_gpu_mod();
#endif
}
auto thp = Faust::basisChebyshev(&L, K, T0, on_gpu);
if(thp) // not NULL
plhs[0] = convertPtr2Mat<Faust::TransformHelper<FPP, Cpu> >(thp);
else {
else
{
plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
double* ptr_out = (double*) mxGetData(plhs[0]);
ptr_out[0] = (double) 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment