Mentions légales du service

Skip to content
Snippets Groups Projects
Commit d1a429b2 authored by Nicolas Bellot's avatar Nicolas Bellot Committed by hhakim
Browse files

wrapper matlab : boolean passer en parametre de mtimes_trans

parent fca4275e
No related branches found
No related tags found
No related merge requests found
...@@ -115,8 +115,8 @@ for i=1:nb_approx_MEG ...@@ -115,8 +115,8 @@ for i=1:nb_approx_MEG
matlab_faustS_mult{i}=@(x) f_mult(sp_facts,x); % function handle matlab_faustS_mult{i}=@(x) f_mult(sp_facts,x); % function handle
% wrapper C++ faust % wrapper C++ faust
trans_faustS_mult{i}=@(x) mtimes_trans(fc,x,'T');% function handle trans_faustS_mult{i}=@(x) mtimes_trans(fc,x,1);% function handle
faustS_mult{i}=@(x) mtimes_trans(fc,x,'N'); % function handle faustS_mult{i}=@(x) mtimes_trans(fc,x,0); % function handle
MEG_faustS{i}=fc; % store the different faust approximations MEG_faustS{i}=fc; % store the different faust approximations
end end
M=size(X_norm,2); M=size(X_norm,2);
......
...@@ -113,7 +113,7 @@ for i=1:nb_mult ...@@ -113,7 +113,7 @@ for i=1:nb_mult
t4=toc; t4=toc;
tic tic
yfaust_mtimes_trans=mtimes_trans(hadamard_faust,x,'T'); yfaust_mtimes_trans=mtimes_trans(hadamard_faust,x,1);
t5=toc; t5=toc;
tic tic
......
...@@ -167,7 +167,7 @@ for i=1:Nb_mult+1 ...@@ -167,7 +167,7 @@ for i=1:Nb_mult+1
tfaust=toc; tfaust=toc;
t_faust(i,j,k,l,1)=tfaust; t_faust(i,j,k,l,1)=tfaust;
tic; tic;
yfaust_trans=mtimes_trans(faust_transform,x,'T'); yfaust_trans=mtimes_trans(faust_transform,x,1);
tfaust_trans=toc; tfaust_trans=toc;
t_faust(i,j,k,l,2)=tfaust_trans; t_faust(i,j,k,l,2)=tfaust_trans;
end end
......
...@@ -296,7 +296,8 @@ x=zeros(dim2,1); ...@@ -296,7 +296,8 @@ x=zeros(dim2,1);
x(:)=1:dim2; x(:)=1:dim2;
x_trans=zeros(dim1,1); x_trans=zeros(dim1,1);
x_trans(:)=1:dim1; x_trans(:)=1:dim1;
istransposed=1;
nontransposed=0;
y_expected = F_dense*x; y_expected = F_dense*x;
y_expected_trans = F_dense'*x_trans; y_expected_trans = F_dense'*x_trans;
...@@ -313,25 +314,25 @@ if (y_expected_trans~= y_star_trans) ...@@ -313,25 +314,25 @@ if (y_expected_trans~= y_star_trans)
end end
y_mtimes_trans = mtimes_trans(F,x_trans,'T'); y_mtimes_trans = mtimes_trans(F,x_trans,istransposed);
if (y_expected_trans ~= y_mtimes_trans) if (y_expected_trans ~= y_mtimes_trans)
error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]);
end end
y_mtimes = mtimes_trans(F,x,'N'); y_mtimes = mtimes_trans(F,x,nontransposed);
if (y_expected ~= y_mtimes) if (y_expected ~= y_mtimes)
error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]);
end end
y_mtimes_trans_N = mtimes_trans(F_trans,x_trans,'N'); y_mtimes_trans_N = mtimes_trans(F_trans,x_trans,nontransposed);
if (y_expected_trans ~= y_mtimes_trans_N) if (y_expected_trans ~= y_mtimes_trans_N)
error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]);
end end
y_mtimes_trans_T = mtimes_trans(F_trans,x,'T'); y_mtimes_trans_T = mtimes_trans(F_trans,x,istransposed);
if (y_expected ~= y_mtimes_trans_T) if (y_expected ~= y_mtimes_trans_T)
error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]);
end end
...@@ -376,24 +377,24 @@ if (Y_expected_trans~= Y_star_trans) ...@@ -376,24 +377,24 @@ if (Y_expected_trans~= Y_star_trans)
end end
Y_mtimes_trans = mtimes_trans(F,X_trans,'T'); Y_mtimes_trans = mtimes_trans(F,X_trans,istransposed);
if (Y_expected_trans ~= Y_mtimes_trans) if (Y_expected_trans ~= Y_mtimes_trans)
error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]);
end end
Y_mtimes = mtimes_trans(F,X,'N'); Y_mtimes = mtimes_trans(F,X,nontransposed);
if (Y_expected ~= Y_mtimes) if (Y_expected ~= Y_mtimes)
error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]);
end end
Y_mtimes_trans_N = mtimes_trans(F_trans,X_trans,'N'); Y_mtimes_trans_N = mtimes_trans(F_trans,X_trans,nontransposed);
if (Y_expected_trans ~= Y_mtimes_trans_N) if (Y_expected_trans ~= Y_mtimes_trans_N)
error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]);
end end
Y_mtimes_trans_T = mtimes_trans(F_trans,X,'T'); Y_mtimes_trans_T = mtimes_trans(F_trans,X,istransposed);
if (y_expected ~= y_mtimes_trans_T) if (y_expected ~= y_mtimes_trans_T)
error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]);
end end
......
...@@ -125,6 +125,8 @@ disp(['tps F_trans(:,1) : ' num2str(mean(t_access_trans_col))]); ...@@ -125,6 +125,8 @@ disp(['tps F_trans(:,1) : ' num2str(mean(t_access_trans_col))]);
%% test faust multiplication with vector %% test faust multiplication with vector
disp('TEST MULTIPLICATION BY A VECTOR : '); disp('TEST MULTIPLICATION BY A VECTOR : ');
istransposed=1;
nontransposed=0;
x=zeros(dim2,1); x=zeros(dim2,1);
x(:)=1:dim2; x(:)=1:dim2;
x_trans=zeros(dim1,1); x_trans=zeros(dim1,1);
...@@ -164,9 +166,9 @@ for i=1:nb_multiplication_vector ...@@ -164,9 +166,9 @@ for i=1:nb_multiplication_vector
error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector with transposition : invalid result within the precision ' num2str(threshold)]);
end end
%% mtimes_trans(F,x_trans,'T'); %% mtimes_trans(F,x_trans,istransposed);
tic tic
y_mtimes_trans = mtimes_trans(F,x_trans,'T'); y_mtimes_trans = mtimes_trans(F,x_trans,istransposed);
t_mtimes_trans(i) = toc; t_mtimes_trans(i) = toc;
if (y_expected_trans ~= y_mtimes_trans) if (y_expected_trans ~= y_mtimes_trans)
...@@ -175,9 +177,9 @@ for i=1:nb_multiplication_vector ...@@ -175,9 +177,9 @@ for i=1:nb_multiplication_vector
%% mtimes_trans(F,x,'N'); %% mtimes_trans(F,x,nontransposed);
tic tic
y_mtimes = mtimes_trans(F,x,'N'); y_mtimes = mtimes_trans(F,x,nontransposed);
t_mtimes(i) = toc; t_mtimes(i) = toc;
if (y_expected ~= y_mtimes) if (y_expected ~= y_mtimes)
error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]); error(['multiplication faust-vector : invalid result within the precision ' num2str(threshold)]);
...@@ -190,8 +192,8 @@ end ...@@ -190,8 +192,8 @@ end
disp(['tps A=A'' : ' num2str(mean(t_trans))]); disp(['tps A=A'' : ' num2str(mean(t_trans))]);
disp(['tps A*x : ' num2str(mean(t_times))]); disp(['tps A*x : ' num2str(mean(t_times))]);
disp(['tps A''*x : ' num2str(mean(t_trans_times))]); disp(['tps A''*x : ' num2str(mean(t_trans_times))]);
disp(['tps mtimes_trans(F,x,''N'') : ' num2str(mean(t_mtimes))]); disp(['tps mtimes_trans(F,x,nontransposed) : ' num2str(mean(t_mtimes))]);
disp(['tps mtimes_trans(F,x_trans,''T'') : ' num2str(mean(t_mtimes_trans))]); disp(['tps mtimes_trans(F,x_trans,istransposed) : ' num2str(mean(t_mtimes_trans))]);
......
...@@ -147,8 +147,14 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) ...@@ -147,8 +147,14 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if(core_ptr->size() == 0) if(core_ptr->size() == 0)
mexErrMsgTxt("get_product : empty faust core"); mexErrMsgTxt("get_product : empty faust core");
mxChar * char_array=mxGetChars(prhs[2]); bool transpose_flag = (bool) mxGetScalar(prhs[2]);
char op=char_array[0];
char op;
if (transpose_flag)
op='T';
else
op='N';
faust_unsigned_int nbRowOp,nbColOp; faust_unsigned_int nbRowOp,nbColOp;
(*core_ptr).setOp(op,nbRowOp,nbColOp); (*core_ptr).setOp(op,nbRowOp,nbColOp);
const size_t SIZE_B1 = nbRowOp; const size_t SIZE_B1 = nbRowOp;
...@@ -206,8 +212,15 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) ...@@ -206,8 +212,15 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if (nelem != 1) if (nelem != 1)
mexErrMsgTxt("invalid char argument."); mexErrMsgTxt("invalid char argument.");
mxChar * char_array=mxGetChars(prhs[3]); // boolean flag to know if the faust si transposed
char op=char_array[0]; bool transpose_flag = (bool) mxGetScalar(prhs[3]);
char op;
if (transpose_flag)
op='T';
else
op='N';
const size_t nbRowA = mxGetM(prhs[2]); const size_t nbRowA = mxGetM(prhs[2]);
const size_t nbColA = mxGetN(prhs[2]); const size_t nbColA = mxGetN(prhs[2]);
......
...@@ -36,40 +36,31 @@ classdef Faust ...@@ -36,40 +36,31 @@ classdef Faust
%% Multiplication faust-vector or faust-matrix %% Multiplication faust-vector or faust-matrix
function y = mtimes(this,x) function y = mtimes(this,x)
% mtimes - overloading of the matlab multiplication (*) function, compatible with matlab matrix and vector % mtimes - overloading of the matlab multiplication (*) function, compatible with matlab matrix and vector
if (this.transpose_flag) y = mexFaust('multiply', this.matrix.objectHandle,x,this.transpose_flag);
trans='T';
else
trans='N';
end
y = mexFaust('multiply', this.matrix.objectHandle,x,trans);
end end
%% Multiplication by a faust or its transpose %% Multiplication by a faust or its transpose
% if trans = 'N' multiplication by faust % if trans = 0 multiplication by faust
% if trans = 'T' multiplication the transpose of a faust % if trans = 1 multiplication by the transpose of a faust
function y = mtimes_trans(this,x,trans) function y = mtimes_trans(this,x,trans)
if ~isreal(trans)
if xor(strcmp(trans,'T'),this.transpose_flag) error('invalid argument trans, must be equal to 0 or 1');
trans='T'; end
else
trans='N'; if (trans ~= 1) && (trans ~= 0)
end error('invalid argument trans, must be equal to 0 or 1');
end
y = mexFaust('multiply', this.matrix.objectHandle,x,trans);
isreally_trans=xor(trans,this.transpose_flag);
y = mexFaust('multiply', this.matrix.objectHandle,x,isreally_trans);
end end
%% Evaluate the product of a faust_core %% Evaluate the product of a faust_core
function y = get_product(this) function y = get_product(this)
% get_product - compute the dense matrix equivalent to the faust (the product of sparse matrix) % get_product - compute the dense matrix equivalent to the faust (the product of sparse matrix)
if this.transpose_flag y=mexFaust('get_product',this.matrix.objectHandle,this.transpose_flag);
trans='T';
else
trans='N';
end
y=mexFaust('get_product',this.matrix.objectHandle,trans);
end end
...@@ -84,10 +75,11 @@ classdef Faust ...@@ -84,10 +75,11 @@ classdef Faust
function trans=ctranspose(this) function trans=ctranspose(this)
%ctranspose - overloading of the matlab transpose operator (') %ctranspose - overloading of the matlab transpose operator (')
trans=this; % trans and this point share the same C++ underlying object (objectHandle) trans=this; % trans and this point share the same C++ underlying object (objectHandle)
trans.transpose_flag = mod(this.transpose_flag+1,2); % inverse the transpose flag trans.transpose_flag = xor(1,this.transpose_flag); % inverse the transpose flag
end end
%% Size %% Size
function varargout = size(this,varargin) function varargout = size(this,varargin)
%size - overload of the matlab size function %size - overload of the matlab size function
...@@ -230,7 +222,7 @@ classdef Faust ...@@ -230,7 +222,7 @@ classdef Faust
transpose_evaluation = (nb_col_selected > nb_row_selected); transpose_evaluation = (nb_col_selected > nb_row_selected);
if transpose_evaluation if transpose_evaluation
identity=eye(dim1); identity=eye(dim1);
transpose_flag='T'; transpose_flag=1;
% switch the 2 different slicing % switch the 2 different slicing
tmp=slicing_row; tmp=slicing_row;
...@@ -239,7 +231,7 @@ classdef Faust ...@@ -239,7 +231,7 @@ classdef Faust
else else
identity=eye(dim2); identity=eye(dim2);
transpose_flag='N'; transpose_flag=0;
end end
% selects the column of the identity, if slicing_col is a char, all % selects the column of the identity, if slicing_col is a char, all
...@@ -262,20 +254,6 @@ classdef Faust ...@@ -262,20 +254,6 @@ classdef Faust
submatrix=submatrix'; submatrix=submatrix';
end end
%% former way not optimized to get access to the row
% nbcol=size(this,2);
% identity=eye(nbcol);
%
% if ~ischar(slicing_col)
% identity=identity(:,slicing_col);
% end
%
% submatrix=this*identity;
%
% if ~ischar(slicing_row)
% submatrix=submatrix(slicing_row,:);
% end
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment