Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 15eed53a authored by hhakim's avatar hhakim
Browse files

Fix matlab cpp compilation error in faust_optimize_time_prod (GPU version).

parent 3840a7eb
Branches
Tags
No related merge requests found
......@@ -41,8 +41,9 @@ namespace Faust
MatDense<FPP,GPU2> get_product(int prod_mod=-1);
void get_product(MatDense<FPP,GPU2>& M, int prod_mod=-1);
void get_product(MatDense<FPP,Cpu>& M, int prod_mod=-1);
MatDense<FPP,GPU2> multiply(const MatDense<FPP,GPU2> &A);
MatDense<FPP,GPU2> multiply(const MatSparse<FPP,GPU2> &A) { return multiply(MatDense<FPP,GPU2>(A));}
MatDense<FPP, GPU2> multiply(const MatDense<FPP,GPU2> &A);
MatDense<FPP, GPU2> multiply(const MatSparse<FPP,GPU2> &A) { return multiply(MatDense<FPP,GPU2>(A));}
MatDense<FPP, Cpu> multiply(const MatSparse<FPP,Cpu> &A) { return multiply(MatDense<FPP, Cpu>(A));} // TODO: avoid CPU copy to dense
MatDense<FPP,Cpu> multiply(const MatDense<FPP,Cpu> &A);
TransformHelper<FPP,GPU2>* multiply(const FPP& a);
TransformHelper<FPP,GPU2>* multiply(const TransformHelper<FPP,GPU2>*);
......
......@@ -29,14 +29,23 @@ void faust_optimize_time_prod(const mxArray **prhs, const int nrhs, mxArray **pl
bool inplace = (bool) mxGetScalar(prhs[3]);
int nsamples = (int) mxGetScalar(prhs[4]);
const mxArray *mat = nullptr;
const MatGeneric<SCALAR, Cpu>* matGen = nullptr;
const MatGeneric<SCALAR, DEV>* matGen = nullptr;
Faust::MatSparse<SCALAR,Cpu> sp_mat;
Faust::MatDense<SCALAR,Cpu> ds_mat;
Faust::MatSparse<SCALAR,DEV> dsp_mat;
Faust::MatDense<SCALAR,DEV> dds_mat;
mat = prhs[5];
if(mxIsSparse(mat))
{
mxArray2FaustspMat<SCALAR>(mat, sp_mat);
matGen = &sp_mat;
if(DEV != Cpu)
{
// GPU mat
dsp_mat = sp_mat;
matGen = &dsp_mat;
}
else
matGen = (Faust::MatSparse<SCALAR, DEV>*) &sp_mat; // cast to avoid compil error but the bad case (DEV == GPU2) will never occur
}
else
{
......@@ -47,7 +56,13 @@ void faust_optimize_time_prod(const mxArray **prhs, const int nrhs, mxArray **pl
ds_mat.resize(mat_nrows, mat_ncols);
memcpy(ds_mat.getData(), ptr_data, mat_ncols*mat_ncols*sizeof(SCALAR));
delete [] ptr_data;
matGen = &ds_mat;
if(DEV == Cpu)
matGen = (Faust::MatDense<SCALAR, DEV>*) &ds_mat; // cast to avoid compil error but the bad case (DEV == GPU2) will never occur
else
{
dds_mat = ds_mat;
matGen = &dds_mat;
}
}
Faust::TransformHelper<SCALAR,DEV>* th = core_ptr->optimize_time_prod(matGen, transp, inplace, nsamples);
if(inplace /*th == nullptr*/)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment