Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 9654664b authored by hhakim's avatar hhakim
Browse files

Write the matfaust wrapper for C++ Faust::HierarchicalFactFFT and add a unit test.

The code is to be reviewed again and documented (matfaust.FaustFactory.fgft_palm()).
parent 9de703ba
Branches
Tags
No related merge requests found
......@@ -158,6 +158,43 @@ classdef FaustFactoryTest < matlab.unittest.TestCase
this.verifyEqual(norm(E,'fro')/norm(M,'fro'), 1.0063, 'AbsTol', 0.0001)
end
function test_fgft_palm(this)
disp('Test FaustFactory.fgft_palm()')
import matfaust.*
import matfaust.factparams.*
num_facts = 4;
is_update_way_R2L = false;
init_lambda = 1.0;
load([this.faust_paths{1},'../../../misc/data/mat/HierarchicalFactFFT_test_U_L_params.mat'])
% U, Lap, init_D, params are loaded from file
fact_cons = cell(3,1);
res_cons = cell(3, 1);
fact_cons{1} = ConstraintInt(ConstraintName(ConstraintName.SP), 128, 128, 12288);
fact_cons{2} = ConstraintInt(ConstraintName(ConstraintName.SP), 128, 128, 6144);
fact_cons{3} = ConstraintInt(ConstraintName(ConstraintName.SP), 128, 128, 3072);
res_cons{1} = ConstraintInt(ConstraintName(ConstraintName.SP), 128, 128, 384);
res_cons{2} = ConstraintInt(ConstraintName(ConstraintName.SP), 128, 128, 384);
res_cons{3} = ConstraintInt(ConstraintName(ConstraintName.SP), 128, 128, 384);
stop_crit = StoppingCriterion(params.niter1);
stop_crit2 = StoppingCriterion(params.niter2);
params.fact_side = 0 % forced
params.verbose = 1 % forced
params.init_lambda = 128;
params = ParamsHierarchicalFact(fact_cons, res_cons, stop_crit, stop_crit2, 'is_fact_side_left', params.fact_side == 1, 'is_update_way_R2L', params.update_way == 1, 'init_lambda', params.init_lambda, 'step_size', params.stepsize, 'constant_step_size', false, 'is_verbose', true);
diag_init_D = diag(init_D)
[F,lambda] = FaustFactory.fgft_palm(U, Lap, params, diag_init_D)
this.verifyEqual(size(F), size(U))
%disp('norm F: ')
%norm(F, 'fro')
E = full(F)-U;
err = norm(E,'fro')/norm(U,'fro')
% matrix to factorize and reference relative error come from
% misc/test/src/C++/hierarchicalFactorizationFFT.cpp
this.verifyEqual(err, 0.05539, 'AbsTol', 0.00001)
end
function testHadamard(this)
disp('Test FaustFactory.wht()')
import matfaust.*
......
......@@ -992,14 +992,14 @@ class TestFaustFactory(unittest.TestCase):
loadmat(sys.path[-1]+"/../../../misc/data/mat/HierarchicalFactFFT_test_U_L_params.mat")['init_D']
params_struct = \
loadmat(sys.path[-1]+'/../../../misc/data/mat/HierarchicalFactFFT_test_U_L_params.mat')['params']
nfacts = params_struct['nfacts'][0,0][0,0]
nfacts = params_struct['nfacts'][0,0][0,0] #useless
niter1 = params_struct['niter1'][0,0][0,0]
niter2 = params_struct['niter2'][0,0][0,0]
verbose = params_struct['verbose'][0,0][0,0]==1
is_update_way_R2L = params_struct['update_way'][0,0][0,0]==1
init_lambda = params_struct['init_lambda'][0,0][0,0]
stepsize = params_struct['stepsize'][0,0][0,0]
factside = params_struct['fact_side'][0,0][0,0] == 1
factside = params_struct['fact_side'][0,0][0,0] == 1 # ignored
# for convenience I set the constraints manually and don't take them
# from mat file, but they are the same
# default step_size
......
......@@ -16,7 +16,7 @@ namespace Faust
public:
//TODO: move def. code in .hpp
HierarchicalFactFFT(const MatDense<FPP,DEVICE>& U, const MatDense<FPP,DEVICE>& Lap, ParamsFFT<FPP,DEVICE,FPP2>& params, BlasHandle<DEVICE> cublasHandle, SpBlasHandle<DEVICE> cusparseHandle): HierarchicalFact<FPP, DEVICE, FPP2>(U, params, cublasHandle, cusparseHandle)
HierarchicalFactFFT(const MatDense<FPP,DEVICE>& U, const MatDense<FPP,DEVICE>& Lap, const ParamsFFT<FPP,DEVICE,FPP2>& params, BlasHandle<DEVICE> cublasHandle, SpBlasHandle<DEVICE> cusparseHandle): HierarchicalFact<FPP, DEVICE, FPP2>(U, params, cublasHandle, cusparseHandle)
{
if ((U.getNbRow() != params.m_nbRow) | (U.getNbCol() != params.m_nbCol))
handleError(m_className,"constructor : params and Fourier matrix U haven't compatible size");
......
......@@ -142,7 +142,7 @@ namespace Faust
void check_constraint_validity();
void check_bool_validity();
void Display() const;
virtual void Display() const;
~Params(){}
......
......@@ -13,7 +13,7 @@ namespace Faust
{
public:
MatDense<FPP, DEVICE> init_D; //TODO: convert to Sparse or Diag repres.
MatDense<FPP, DEVICE> init_D; //TODO: convert to Sparse or Diag repres. and set private or protected
//TODO: does it really need to be public
//TODO: move the ctor def into .hpp
ParamsFFT(
......@@ -55,7 +55,6 @@ namespace Faust
// set init_D from diagonal vector init_D_diag
for(int i=0;i<nbRow;i++)
init_D.getData()[i*nbRow+i] = init_D_diag.getData()[i];
}
ParamsFFT(
......@@ -81,6 +80,9 @@ namespace Faust
ParamsFFT() {}
void Display() const;
};
}
......
template<typename FPP,Device DEVICE,typename FPP2>
void Faust::ParamsFFT<FPP,DEVICE,FPP2>::Display() const
{
Faust::Params<FPP,DEVICE,FPP2>::Display();
cout << "init_D isIdentity:" << init_D.estIdentite() << endl;
cout << "init_D info:" << endl;
init_D.Display();
cout << "ParamsFFT init_D norm: " << init_D.norm() << endl;
}
......@@ -262,7 +262,7 @@ classdef FaustFactory
error('M''s number of columns must be consistent with the last residuum constraint defined in p. Likewise its number of rows must be consistent with the first factor constraint defined in p.')
end
% the setters for num_rows/cols verifies consistency with constraints
mex_params = struct('nfacts', p.num_facts, 'cons', {mex_constraints}, 'niter1', p.stop_crits{1}.num_its,'niter2', p.stop_crits{2}.num_its, 'sc_is_criterion_error', p.stop_crits{1}.is_criterion_error, 'sc_error_treshold', p.stop_crits{1}.error_treshold, 'sc_max_num_its', p.stop_crits{1}.max_num_its, 'sc_is_criterion_error2', p.stop_crits{2}.is_criterion_error, 'sc_error_treshold2', p.stop_crits{2}.error_treshold, 'sc_max_num_its2', p.stop_crits{2}.max_num_its, 'nrow', p.data_num_rows, 'ncol', p.data_num_cols, 'fact_side', p.is_fact_side_left, 'update_way', p.is_update_way_R2L);
mex_params = struct('nfacts', p.num_facts, 'cons', {mex_constraints}, 'niter1', p.stop_crits{1}.num_its,'niter2', p.stop_crits{2}.num_its, 'sc_is_criterion_error', p.stop_crits{1}.is_criterion_error, 'sc_error_treshold', p.stop_crits{1}.error_treshold, 'sc_max_num_its', p.stop_crits{1}.max_num_its, 'sc_is_criterion_error2', p.stop_crits{2}.is_criterion_error, 'sc_error_treshold2', p.stop_crits{2}.error_treshold, 'sc_max_num_its2', p.stop_crits{2}.max_num_its, 'nrow', p.data_num_rows, 'ncol', p.data_num_cols, 'fact_side', p.is_fact_side_left, 'update_way', p.is_update_way_R2L, 'verbose', p.is_verbose, 'init_lambda', p.init_lambda);
if(isreal(M))
[lambda, core_obj] = mexHierarchical_factReal(M, mex_params);
else
......@@ -272,6 +272,66 @@ classdef FaustFactory
varargout = {F, lambda, p};
end
function varargout = fgft_palm(U, Lap, p, varargin)
import matfaust.Faust
import matfaust.factparams.*
% TODO: check U, Lap sizes, same field
% TODO: refactor with fact_hierarchical
if(length(varargin) == 1)
init_D = varargin{1};
if(~ ismatrix(init_D) || ~ isnumeric(init_D))
error('fgft_palm arg. 4 must be a matrix')
end
elseif(length(varargin) > 1)
error('fgft_palm, too many arguments.')
else % nargin == 0
init_D = ones(size(U,1));
if(~ isreal(U))
init_D = complex(init_D);
end
end
matfaust.FaustFactory.check_fact_mat('FaustFactory.fgft_palm', U)
if(~ isa(p, 'ParamsHierarchicalFact') && ParamsFactFactory.is_a_valid_simplification(p))
p = ParamsFactFactory.createParams(U, p);
end
mex_constraints = cell(2, p.num_facts-1);
if(~ isa(p ,'ParamsHierarchicalFact'))
error('p must be a ParamsHierarchicalFact object.')
end
%mex_fact_constraints = cell(1, p.num_facts-1)
for i=1:p.num_facts-1
cur_cell = cell(1, 4);
cur_cell{1} = p.constraints{i}.name.conv2str();
cur_cell{2} = p.constraints{i}.param;
cur_cell{3} = p.constraints{i}.num_rows;
cur_cell{4} = p.constraints{i}.num_cols;
%mex_fact_constraints{i} = cur_cell;
mex_constraints{1,i} = cur_cell;
end
%mex_residuum_constraints = cell(1, p.num_facts-1)
for i=1:p.num_facts-1
cur_cell = cell(1, 4);
cur_cell{1} = p.constraints{i+p.num_facts-1}.name.conv2str();
cur_cell{2} = p.constraints{i+p.num_facts-1}.param;
cur_cell{3} = p.constraints{i+p.num_facts-1}.num_rows;
cur_cell{4} = p.constraints{i+p.num_facts-1}.num_cols;
%mex_residuum_constraints{i} = cur_cell;
mex_constraints{2,i} = cur_cell;
end
if(~ p.is_mat_consistent(U))
error('U''s number of columns must be consistent with the last residuum constraint defined in p. Likewise its number of rows must be consistent with the first factor constraint defined in p.')
end
% the setters for num_rows/cols verifies consistency with constraints
mex_params = struct('nfacts', p.num_facts, 'cons', {mex_constraints}, 'niter1', p.stop_crits{1}.num_its,'niter2', p.stop_crits{2}.num_its, 'sc_is_criterion_error', p.stop_crits{1}.is_criterion_error, 'sc_error_treshold', p.stop_crits{1}.error_treshold, 'sc_max_num_its', p.stop_crits{1}.max_num_its, 'sc_is_criterion_error2', p.stop_crits{2}.is_criterion_error, 'sc_error_treshold2', p.stop_crits{2}.error_treshold, 'sc_max_num_its2', p.stop_crits{2}.max_num_its, 'nrow', p.data_num_rows, 'ncol', p.data_num_cols, 'fact_side', p.is_fact_side_left, 'update_way', p.is_update_way_R2L, 'init_D', init_D, 'verbose', p.is_verbose, 'init_lambda', p.init_lambda);
if(isreal(U))
[lambda, core_obj] = mexHierarchical_factReal(U, mex_params, Lap);
else
[lambda, core_obj] = mexHierarchical_factCplx(U, mex_params, Lap);
end
F = Faust(core_obj, isreal(U));
varargout = {F, lambda, p};
end
%==========================================================================================
%> @brief Constructs a Faust implementing the Walsh-Hadamard Transform of order 2^n.
%>
......
......@@ -40,6 +40,7 @@
#include "mex.h"
#include "faust_HierarchicalFact.h"
#include "faust_HierarchicalFactFFT.h"
#include "faust_TransformHelper.h"
#include "class_handle.hpp"
#include <vector>
......@@ -68,12 +69,16 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
system("sleep 7");
#endif
if (nrhs != 2)
{
bool is_fgft = false;
if(nrhs == 3)
is_fgft = true;
else if(nrhs == 2)
is_fgft = false;
else
mexErrMsgTxt("Bad Number of inputs arguments");
}
const mxArray* matlab_matrix = prhs[0];
const mxArray* matlab_params = prhs[1];
const mxArray* matlab_matrix = prhs[0];
const mxArray* matlab_params = prhs[1];
if(!mxIsStruct(matlab_params))
{
mexErrMsgTxt("Input must be a structure.");
......@@ -86,7 +91,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
auto *params = mxArray2FaustParams<SCALAR,FPP2>(matlab_params);
///////////// HIERARCHICAL LAUNCH ///////////////
// creation des parametres
try{
// std::cout<<"nb_row : "<<nb_row<<std::endl;
// std::cout<<"nb_col : "<<nb_col<<std::endl;
......@@ -99,7 +103,19 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Faust::HierarchicalFact<SCALAR,Cpu, FPP2>* hier_fact;
hier_fact = new Faust::HierarchicalFact<SCALAR,Cpu,FPP2>(matrix,*params,blas_handle,spblas_handle);
if(dynamic_cast<const ParamsFFT<SCALAR,Cpu,FPP2>*>(params))
{
if(! is_fgft) mexErrMsgTxt("Bad Number of inputs arguments for PALM FGFT.");
Faust::MatDense<SCALAR,Cpu> laplacian_mat;
const mxArray* matlab_lap_mat = prhs[2];
mxArray2FaustMat(matlab_lap_mat,laplacian_mat);
params->Display();
hier_fact = new Faust::HierarchicalFactFFT<SCALAR,Cpu,FPP2>(matrix,laplacian_mat,*dynamic_cast<const ParamsFFT<SCALAR,Cpu,FPP2>*>(params),blas_handle,spblas_handle);
}
else
hier_fact = new Faust::HierarchicalFact<SCALAR,Cpu,FPP2>(matrix,*params,blas_handle,spblas_handle);
hier_fact->compute_facts();
......
......@@ -101,6 +101,8 @@ const string mat_field_type2str(MAT_FIELD_TYPE f)
return "sc_max_num_its2";
case INIT_FACTS:
return "init_facts";
case INIT_D:
return "init_D";
}
}
......@@ -124,21 +126,8 @@ void testCoherence(const mxArray* params,std::vector<bool> & presentFields)
presentFields.resize(MAT_FIELD_TYPE_LEN);
presentFields.assign(MAT_FIELD_TYPE_LEN,false);
//TODO: this function should be modified to be more reliable
// the consistency between the matlab structure
// and the STL vector is not enough reliable and is prone to errors when extending structure with new fields
// maybe a structure in C++ should be used instead of a vector to avoid index overflow situations
// (that kind of bugs already happened)
// the structure would be a C++ equivalent of the matlab struct with default values.
// Fields would be retrieved by name not by index
// By the way it would be an optimization from the caller pt of view not having to call mxGetField
// (and neither to do a bunch of strcmp-s for each field of the structure)
// In brief, consistency check + fields retrival in one time + error raising when *
// unknown fields found or mandatory fields not found
// An equivalent function is located in mexPalm4MSA (testCoherencePalm4MSA()) and also applies to this modification
if(nbr_field < 3)
mexErrMsgTxt("The number of field of params must be at least 3 ");
mexErrMsgTxt("The number of fields in params must be at least 3 ");
for(int i=0;i<nbr_field;i++)
presentFields[mat_field_str2type(string(mxGetFieldNameByNumber(params,i)))] = true;
......
......@@ -171,12 +171,13 @@ enum MAT_FIELD_TYPE
SC_ERROR_TRESHOLD2,
SC_MAX_NUM_ITS2,
INIT_FACTS,
};
INIT_D //only for FactHierarchicalF(G)FT
};// if you wan to add a field, dont forget to update mat_field_type2str() and MAT_FIELD_TYPE_LEN
const string mat_field_type2str(MAT_FIELD_TYPE f);
const MAT_FIELD_TYPE mat_field_str2type(const string& fstr);
const unsigned int MAT_FIELD_TYPE_LEN = 17;
const unsigned int MAT_FIELD_TYPE_LEN = 18; // must be the number of fields in MAT_FIELD_TYPE
void testCoherence(const mxArray* params,std::vector<bool> & presentFields);
......
......@@ -50,6 +50,7 @@
#include "faust_ConstraintMat.h"
#include "faust_ConstraintInt.h"
#include "faust_Params.h"
#include "faust_ParamsFFT.h"
#include "faust_MatDense.h"
#include "faust_MatSparse.h"
#include "faust_Vect.h"
......@@ -756,13 +757,27 @@ const Params<SCALAR, Cpu, FPP2>* mxArray2FaustParams(const mxArray* matlab_param
if (presentFields[INIT_LAMBDA])
{
mxCurrentField = mxGetField(matlab_params,0,mat_field_type2str(INIT_LAMBDA).c_str());
SCALAR* tmp_ptr = &init_lambda;
// SCALAR* tmp_ptr = &init_lambda;
// it works whatever mxCurrentField class is (complex or not)
mxArray2Ptr<SCALAR>(const_cast<const mxArray*>(mxCurrentField), tmp_ptr);
// mxArray2Ptr<SCALAR>(const_cast<const mxArray*>(mxCurrentField), tmp_ptr);
// init_lambda = (SCALAR) mxGetScalar(mxCurrentField);
init_lambda = (SCALAR) mxGetScalar(mxCurrentField);
}
Faust::Params<SCALAR,Cpu,FPP2>* params = new Params<SCALAR,Cpu,FPP2>(nb_row,nb_col,nbFact,consSS,/*std::vector<Faust::MatDense<SCALAR,Cpu> >()*/ init_facts,crit1,crit2,isVerbose,updateway,factside,init_lambda);
Faust::Params<SCALAR,Cpu,FPP2>* params;
if(presentFields[INIT_D])
{
//get the diagonal vector to define the init_D matrix (cf. FactHierarchicalF(G)FT
SCALAR* init_D = new SCALAR[nb_row]; //nb_col == nb_row when using FactHierarchicalF(G)FT
mxCurrentField = mxGetField(matlab_params,0,mat_field_type2str(INIT_D).c_str());
mxArray2Ptr<SCALAR>(const_cast<const mxArray*>(mxCurrentField), init_D);
params = new ParamsFFT<SCALAR,Cpu,FPP2>(nb_row,nb_col,nbFact,consSS, init_facts, init_D, crit1,crit2,isVerbose,updateway,factside,init_lambda);
delete init_D;
}
else
{
params = new Params<SCALAR,Cpu,FPP2>(nb_row,nb_col,nbFact,consSS,/*std::vector<Faust::MatDense<SCALAR,Cpu> >()*/ init_facts,crit1,crit2,isVerbose,updateway,factside,init_lambda);
}
return params;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment