Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 39ea4c04 authored by Nicolas Bellot's avatar Nicolas Bellot Committed by hhakim
Browse files

wrapper matlab check scalar compatibility (real only compatible with real, not complex)

parent 71834d5d
Branches
Tags
No related merge requests found
......@@ -47,7 +47,7 @@ threshold = 10^(-5);
disp('****** TEST MATLAB_FAUST ******* ');
disp('CONSTRUCTOR ');
disp('test 1 : ');
disp('test 1 : invalid factor size ');
test_pass = 0;
expected_err_message='Faust::Transform<FPP,Cpu> : check_factors_validity : dimensions of the factors mismatch';
factors=cell(1,nb_fact);
......@@ -74,7 +74,7 @@ end
disp('Ok');
disp('test 2 : ');
disp('test 2 : invalid factor empty');
test_pass = 0;
expected_err_message='concatMatGeneric : empty matlab matrix';
factors=cell(1,nb_fact); % each cell is empty, must contained a matrix
......@@ -100,7 +100,7 @@ disp('Ok');
disp('test 3 : ');
disp('test 3 : invalid factor type (cell)');
test_pass = 0;
expected_err_message='getFaustMat :input matrix format must be single or double';
factors=cell(1,1);
......@@ -124,7 +124,7 @@ end
disp('Ok');
disp('test 4 : ');
disp('test 4 : subsasgn not implemented');
test_pass = 0;
expected_err_message='function not implemented for Faust class';
F=Faust({ones(5,4),ones(4,7)});
......@@ -145,5 +145,77 @@ end
disp('Ok');
disp('test 5 : multiplication scalar compatibility (real/complex)');
% real scalar Faust
test_pass = 0;
expected_err_message='Multiply : Wrong scalar compatibility (real/complex)';
nbRowF = 3;
nbColF = 4;
F_real = Faust({rand(nbRowF,nbColF)});
% complex scalar vector
nbColVec = 7;
%% avoid use of the variable i for imaginary unit to avoid conflict with the i of a loop for instance
%% prefer the expression 1i (imaginary unit)
complex_vector = rand(nbColF,nbColVec)+ 1i * rand(nbColF,nbColVec);
try
y = F_real * complex_vector;
catch ME
if strcmp(ME.message,expected_err_message)
test_pass = 1;
else
error([ 'error with a wrong message : ' ME.message ' must be : ' expected_err_message ]);
end
end
disp('Ok');
disp('test 6 : invalid dense factor (scalar complex)');
test_pass = 0;
nbRowF = 3;
nbColF = 4;
expected_err_message='getFaustMat scalar type (complex/real) are not compatible';
factor_complex = rand(nbColF,nbColVec)+ 1i * rand(nbColF,nbColVec);
try
F_real = Faust({factor_complex});
catch ME
if strcmp(ME.message,expected_err_message)
test_pass = 1;
else
disp(ME.message);
error([ 'error with a wrong message : ' ME.message ' must be : ' expected_err_message ]);
end
end
disp('Ok');
disp('test 7 : invalid sparse factor (scalar complex)');
test_pass = 0;
nbRowF = 3;
nbColF = 4;
expected_err_message='getFaustspMat scalar type (complex/real) are not compatible';
factor_complex = sparse(rand(nbColF,nbColVec)+ 1i * rand(nbColF,nbColVec));
try
F_real = Faust({factor_complex});
catch ME
if strcmp(ME.message,expected_err_message)
test_pass = 1;
else
disp(ME.message);
error([ 'error with a wrong message : ' ME.message ' must be : ' expected_err_message ]);
end
end
disp('Ok');
......@@ -104,11 +104,7 @@ namespace Faust
* */
void check_factors_validity() const;
/** \brief
* check if the Transform has real scalar or complex scalar
* */
bool isReal() const;
/** \brief Constructor
* \param facts : Vector including dense matrix*/
Transform(const std::vector<Faust::MatDense<FPP,Cpu> >&facts);
......
......@@ -52,8 +52,7 @@
#include <fstream>
#include "faust_BlasHandle.h"
#include "faust_SpBlasHandle.h"
#include <complex>
#include <typeinfo>
......@@ -289,24 +288,7 @@ void Faust::Transform<FPP,Cpu>::updateNonZeros()
}
template<typename FPP>
bool Faust::Transform<FPP,Cpu>::isReal() const
{
bool isReal = (typeid(FPP) == typeid(double));
isReal = (isReal || (typeid(FPP) == typeid(float)) );
bool isComplex = (typeid(FPP) == typeid(std::complex<double>));
isComplex = (isComplex || (typeid(FPP) == typeid(std::complex<float>)) );
if ( (!isComplex) && (!isReal) )
{
handleError(m_className,"isReal : unknown type of scalar");
}
return isReal;
}
......
......@@ -91,7 +91,17 @@ namespace Faust
*/
virtual void faust_gemm(const Faust::MatDense<FPP,DEVICE> & B, Faust::MatDense<FPP,DEVICE> & C,const FPP & alpha, const FPP & beta, char typeA, char typeB)const=0;
/** \brief
* check if the LinearOperator has real scalar or complex scalar
* */
bool isReal() const;
};
}
#include "faust_LinearOperator.hpp"
#endif
#ifndef __FAUST_LINEAR_OPERATOR_HPP__
#define __FAUST_LINEAR_OPERATOR_HPP__
#include <complex>
#include <typeinfo>
#include "faust_exception.h"
template<typename FPP,Device DEVICE>
bool Faust::LinearOperator<FPP,DEVICE>::isReal() const
{
bool isReal = (typeid(FPP) == typeid(double));
isReal = (isReal || (typeid(FPP) == typeid(float)) );
bool isComplex = (typeid(FPP) == typeid(std::complex<double>));
isComplex = (isComplex || (typeid(FPP) == typeid(std::complex<float>)) );
if ( (!isComplex) && (!isReal) )
{
handleError("linearOperator","isReal : unknown type of scalar");
}
return isReal;
}
#endif
......@@ -334,163 +334,172 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
return;
}
if (!strcmp("multiply", cmd)) {
if (nlhs > 1 || nrhs != 4)
mexErrMsgTxt("Multiply: Unexpected arguments.");
mwSize nelem = mxGetNumberOfElements(prhs[3]);
if (nelem != 1)
mexErrMsgTxt("invalid char argument.");
// boolean flag to know if the faust si transposed
bool transpose_flag = (bool) mxGetScalar(prhs[3]);
char op;
if (transpose_flag)
op='T';
else
op='N';
if (!strcmp("multiply", cmd)) {
if (nlhs > 1 || nrhs != 4)
mexErrMsgTxt("Multiply: Unexpected arguments.");
mwSize nelem = mxGetNumberOfElements(prhs[3]);
if (nelem != 1)
mexErrMsgTxt("invalid char argument.");
// boolean flag to know if the faust si transposed
bool transpose_flag = (bool) mxGetScalar(prhs[3]);
char op;
if (transpose_flag)
op='T';
else
op='N';
// input matrix or vector from MATLAB
const mxArray * inMatlabMatrix = prhs[2];
const size_t nbRowA = mxGetM(prhs[2]);
const size_t nbColA = mxGetN(prhs[2]);
faust_unsigned_int nbRowOp_,nbColOp_;
(*core_ptr).setOp(op,nbRowOp_,nbColOp_);
const size_t nbRowOp = nbRowOp_;
const size_t nbColOp = nbColOp_;
const size_t nbRowB = nbRowOp;
const size_t nbColB = nbColA;
const size_t nbRowA = mxGetM(inMatlabMatrix);
const size_t nbColA = mxGetN(inMatlabMatrix);
faust_unsigned_int nbRowOp_,nbColOp_;
(*core_ptr).setOp(op,nbRowOp_,nbColOp_);
const size_t nbRowOp = nbRowOp_;
const size_t nbColOp = nbColOp_;
const size_t nbRowB = nbRowOp;
const size_t nbColB = nbColA;
// Check parameters
if (mxGetNumberOfDimensions(prhs[2]) != 2
|| nbRowA != nbColOp )
mexErrMsgTxt("Multiply: Wrong number of dimensions for the input vector or matrix (third argument).");
/** Check parameters **/
//check dimension match
if (mxGetNumberOfDimensions(inMatlabMatrix) != 2
|| nbRowA != nbColOp )
mexErrMsgTxt("Multiply : Wrong number of dimensions for the input vector or matrix (third argument).");
FFPP* ptr_data = NULL;
//check scalar type match (real/complex)
if ( !isScalarCompatible((*core_ptr),inMatlabMatrix) )
mexErrMsgTxt("Multiply : Wrong scalar compatibility (real/complex)");
const mxClassID V_CLASS_ID = mxGetClassID(prhs[2]);
const size_t NB_ELEMENTS = mxGetNumberOfElements(prhs[2]);
if(V_CLASS_ID == mxDOUBLE_CLASS)
{
double* ptr_data_tmp = static_cast<double*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxSINGLE_CLASS)
{
float* ptr_data_tmp = static_cast<float*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxINT8_CLASS)
{
char* ptr_data_tmp = static_cast<char*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxUINT8_CLASS)
{
unsigned char* ptr_data_tmp = static_cast<unsigned char*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxINT16_CLASS)
{
short* ptr_data_tmp = static_cast<short*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT16_CLASS)
{
unsigned short* ptr_data_tmp = static_cast<unsigned short*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxINT32_CLASS)
{
int* ptr_data_tmp = static_cast<int*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT32_CLASS)
{
unsigned int* ptr_data_tmp = static_cast<unsigned int*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxINT64_CLASS)
{
long long* ptr_data_tmp = static_cast<long long*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT64_CLASS)
{
unsigned long long* ptr_data_tmp = static_cast<unsigned long long*> (mxGetData(prhs[2]));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else
mexErrMsgTxt("Unknown matlab type.");
// Si prhs[2] est un vecteur
if(nbColA == 1)
{
Faust::Vect<FFPP,Cpu> A(nbRowA, ptr_data);
Faust::Vect<FFPP,Cpu> B(nbRowB);
//NB
//B = (*core_ptr)*A;
B = (*core_ptr).multiply(A,op);
const mwSize dims[2]={nbRowB,nbColB};
if(sizeof(FFPP)==sizeof(float))
plhs[0] = mxCreateNumericArray(2, dims, mxSINGLE_CLASS, mxREAL);
else if(sizeof(FFPP)==sizeof(double))
plhs[0] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL);
FFPP* ptr_data = NULL;
const mxClassID V_CLASS_ID = mxGetClassID(inMatlabMatrix);
const size_t NB_ELEMENTS = mxGetNumberOfElements(inMatlabMatrix);
if(V_CLASS_ID == mxDOUBLE_CLASS)
{
double* ptr_data_tmp = static_cast<double*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxSINGLE_CLASS)
{
float* ptr_data_tmp = static_cast<float*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxINT8_CLASS)
{
char* ptr_data_tmp = static_cast<char*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxUINT8_CLASS)
{
unsigned char* ptr_data_tmp = static_cast<unsigned char*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if(V_CLASS_ID == mxINT16_CLASS)
{
short* ptr_data_tmp = static_cast<short*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT16_CLASS)
{
unsigned short* ptr_data_tmp = static_cast<unsigned short*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxINT32_CLASS)
{
int* ptr_data_tmp = static_cast<int*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT32_CLASS)
{
unsigned int* ptr_data_tmp = static_cast<unsigned int*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxINT64_CLASS)
{
long long* ptr_data_tmp = static_cast<long long*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else if (V_CLASS_ID == mxUINT64_CLASS)
{
unsigned long long* ptr_data_tmp = static_cast<unsigned long long*> (mxGetData(inMatlabMatrix));
ptr_data = new FFPP[NB_ELEMENTS];
for (size_t i =0 ; i<NB_ELEMENTS ; i++)
ptr_data[i] = static_cast<FFPP> (ptr_data_tmp[i]);
}
else
mexErrMsgTxt("FFPP type is neither double nor float");
mexErrMsgTxt("Unknown matlab type.");
FFPP* ptr_out = static_cast<FFPP*> (mxGetData(plhs[0]));
memcpy(ptr_out, B.getData(), nbRowB*nbColB*sizeof(FFPP));
}
// Si prhs[2] est une matrice
else
{
Faust::MatDense<FFPP,Cpu> A(ptr_data, nbRowA, nbColA);
Faust::MatDense<FFPP,Cpu> B(nbRowB, nbColA);
// Si inMatlabMatrix est un vecteur
if(nbColA == 1)
{
Faust::Vect<FFPP,Cpu> A(nbRowA, ptr_data);
Faust::Vect<FFPP,Cpu> B(nbRowB);
//NB
//B = (*core_ptr)*A;
B = (*core_ptr).multiply(A,op);
const mwSize dims[2]={nbRowB,nbColB};
if(sizeof(FFPP)==sizeof(float))
plhs[0] = mxCreateNumericArray(2, dims, mxSINGLE_CLASS, mxREAL);
else if(sizeof(FFPP)==sizeof(double))
plhs[0] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL);
const mwSize dims[2]={nbRowB,nbColB};
if(sizeof(FFPP)==sizeof(float))
plhs[0] = mxCreateNumericArray(2, dims, mxSINGLE_CLASS, mxREAL);
else if(sizeof(FFPP)==sizeof(double))
plhs[0] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL);
else
mexErrMsgTxt("FFPP type is neither double nor float");
FFPP* ptr_out = static_cast<FFPP*> (mxGetData(plhs[0]));
memcpy(ptr_out, B.getData(), nbRowB*nbColB*sizeof(FFPP));
}
// Si inMatlabMatrix est une matrice
else
mexErrMsgTxt("FFPP type is neither double nor float");
FFPP* ptr_out = static_cast<FFPP*> (mxGetData(plhs[0]));
memcpy(ptr_out, B.getData(), nbRowB*nbColB*sizeof(FFPP));
}
if(ptr_data) {delete [] ptr_data ; ptr_data = NULL;}
{
Faust::MatDense<FFPP,Cpu> A(ptr_data, nbRowA, nbColA);
Faust::MatDense<FFPP,Cpu> B(nbRowB, nbColA);
B = (*core_ptr).multiply(A,op);
const mwSize dims[2]={nbRowB,nbColB};
if(sizeof(FFPP)==sizeof(float))
plhs[0] = mxCreateNumericArray(2, dims, mxSINGLE_CLASS, mxREAL);
else if(sizeof(FFPP)==sizeof(double))
plhs[0] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL);
else
mexErrMsgTxt("FFPP type is neither double nor float");
FFPP* ptr_out = static_cast<FFPP*> (mxGetData(plhs[0]));
memcpy(ptr_out, B.getData(), nbRowB*nbColB*sizeof(FFPP));
}
if(ptr_data) {delete [] ptr_data ; ptr_data = NULL;}
return;
return;
}
......
......@@ -44,6 +44,7 @@
#include <vector>
#include "faust_constant.h"
namespace Faust {
template<typename FPP, Device DEVICE> class ConstraintGeneric;
template<typename FPP, Device DEVICE> class Vect;
......@@ -51,8 +52,22 @@ namespace Faust {
template<typename FPP, Device DEVICE> class MatGeneric;
template<typename FPP, Device DEVICE> class MatDense;
template<typename FPP, Device DEVICE> class MatSparse;
template<typename FPP, Device DEVICE> class Transform;
template<typename FPP, Device DEVICE> class LinearOperator;
}
/*!
* \brief check if the Faust::Transform T has compatible scalar with MATLAB matrix Matlab_Mat (currently real is only compatible with real and complex is only compatible with complex)
* \param T : Faust::Transform<FPP,Cpu>
* \tparam Matlab_Mat : mxArray pointer
*/
template<typename FPP>
bool isScalarCompatible(const Faust::LinearOperator<FPP,Cpu> & L,const mxArray * Matlab_Mat);
/*!
* \brief convert the matlab mxArray* into a Faust::Vect<FPP,Cpu>, no shared memory
* \param vec_array : pointer to the mxArray* (matlab format) representing a dense column Vector
......
......@@ -53,9 +53,21 @@
#include "faust_MatDense.h"
#include "faust_MatSparse.h"
#include "faust_Vect.h"
#include "faust_Transform.h"
#include "faust_LinearOperator.h"
template<typename FPP>
bool isScalarCompatible(Faust::LinearOperator<FPP,Cpu> & L, const mxArray * Matlab_Mat)
{
bool isMatlabComplex = mxIsComplex(Matlab_Mat);
bool isTransformComplex = !(L.isReal());
return (isMatlabComplex == isTransformComplex);
}
template<typename FPP>
void getFaustVec(const mxArray * vec_array,Faust::Vect<FPP,Cpu> & vec)
{
......@@ -134,7 +146,11 @@ void getFaustMat(const mxArray* Mat_array,Faust::MatDense<FPP,Cpu> & Mat)
{
int nbRow,nbCol;
if (mxIsEmpty(Mat_array))
{
mexErrMsgTxt("tools_mex.h:getFaustMat :input matrix is empty.");
......@@ -156,6 +172,11 @@ void getFaustMat(const mxArray* Mat_array,Faust::MatDense<FPP,Cpu> & Mat)
//mexErrMsgTxt("sparse matrix entry instead of dense matrix");
mexErrMsgIdAndTxt("a","a sparse matrix entry instead of dense matrix");
}
//check scalar compayibility
if (!isScalarCompatible(Mat,Mat_array))
mexErrMsgTxt("getFaustMat scalar type (complex/real) are not compatible");
const mxClassID V_CLASS_ID = mxGetClassID(Mat_array);
FPP* MatPtr;
if (((V_CLASS_ID == mxDOUBLE_CLASS) && (sizeof(double) == sizeof(FPP))) || ((V_CLASS_ID == mxSINGLE_CLASS) && (sizeof(float) == sizeof(FPP))))
......@@ -204,7 +225,11 @@ void getFaustspMat(const mxArray* spMat_array,Faust::MatSparse<FPP,Cpu> & S)
mexErrMsgIdAndTxt("tools_mex.h:getFaustspMat",
"input array must be sparse");
}
int nnzMax = mxGetNzmax(spMat_array);
//check scalar compayibility
if (!isScalarCompatible(S,spMat_array))
mexErrMsgTxt("getFaustspMat scalar type (complex/real) are not compatible");
int nnzMax = mxGetNzmax(spMat_array);
int nbCol = mxGetN(spMat_array);
int nbRow = mxGetM(spMat_array);
//mexPrintf("DIM (%d,%d) NNZMAX : %d\n",nbRow,nbCol,nnzMax);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment