Mentions légales du service

Skip to content
Snippets Groups Projects
Commit c67d1b28 authored by hhakim's avatar hhakim
Browse files

Adjust matfaust (including mex funcs) to support blockdia proj or constraint.

parent 68c2e95d
No related branches found
No related tags found
No related merge requests found
......@@ -59,7 +59,7 @@ bool is_constraint_name_int(const char * type)
is_const_int = ((is_const_int) || ((strcmp(type,"splin") == 0)));
is_const_int = ((is_const_int) || ((strcmp(type,"splincol") == 0)));
is_const_int = ((is_const_int) || ((strcmp(type,"splin") == 0)));
is_const_int = ((is_const_int) || ((strcmp(type,"blkdiag") == 0)));
// is_const_int = ((is_const_int) || ((strcmp(type,"blkdiag") == 0)));
return is_const_int;
......@@ -73,7 +73,7 @@ bool is_constraint_name_real(const char * type)
bool is_constraint_name_mat(const char * type)
{
return ((strcmp(type,"supp") == 0) || (strcmp(type,"const")==0));
return (strcmp(type,"supp") == 0) || (strcmp(type,"const")==0) || ! strcmp(type, "toeplitz") || ! strcmp(type, "circ") || ! strcmp(type, "blockdiag") || !strcmp(type, "blkdiag") || ! strcmp (type, "hankel");
}
faust_constraint_name get_equivalent_constraint(const char * type)
......@@ -93,7 +93,7 @@ faust_constraint_name get_equivalent_constraint(const char * type)
return CONSTRAINT_NAME_CONST;
if (!strcmp(type,"sppos"))
return CONSTRAINT_NAME_SP_POS;
if (!strcmp(type,"blkdiag"))
if (!strcmp(type,"blkdiag") || !strcmp(type, "blockdiag"))
return CONSTRAINT_NAME_BLKDIAG;
if (!strcmp(type,"supp"))
return CONSTRAINT_NAME_SUPP;
......@@ -111,7 +111,6 @@ int get_type_constraint(const char * type)
bool is_const_real = is_constraint_name_real(type);
bool is_const_mat = is_constraint_name_mat(type);
if (is_const_int)
return 0;
if (is_const_real)
......
......@@ -12,13 +12,19 @@ classdef ConstraintMat < matfaust.factparams.ConstraintGeneric
error('ConstraintMat must receive a matrix as param argument.')
end
if(islogical(param))
param = real(param)
param = real(param);
end
constraint = constraint@matfaust.factparams.ConstraintGeneric(name, size(param, 1), size(param, 2), param, varargin{:});
if(isa(name, 'matfaust.factparams.ConstraintName') && name.name == matfaust.factparams.ConstraintName.BLKDIAG || isstr(name) && matfaust.factparams.ConstraintName.str2name_int(name) == matfaust.factparams.ConstraintName.BLKDIAG)
nrows = param(end,1);
ncols = param(end,2);
else
nrows = size(param, 1);
ncols = size(param, 2);
end
constraint = constraint@matfaust.factparams.ConstraintGeneric(name, nrows, ncols, param, varargin{:});
if(~ isa(constraint.name, 'matfaust.factparams.ConstraintName') || ~ constraint.name.is_mat_constraint())
error('ConstraintMat first argument must be a ConstraintName with a matrix type name.')
end
constraint.name.name
if(constraint.default_normalized && constraint.name.name == matfaust.factparams.ConstraintName.CONST)
% for CONST proj the default is to not normalize
constraint.normalized = false;
......
......@@ -16,7 +16,7 @@ classdef ConstraintName
SPLINCOL = 4
CONST = 5
SP_POS = 6
% BLKDIAG = 7
BLKDIAG = 7
SUPP = 8
NORMLIN = 9
TOEPLITZ = 10
......@@ -55,7 +55,7 @@ classdef ConstraintName
end
function is_mat = is_mat_constraint(obj)
is_mat = obj.name == obj.SUPP || obj.name == obj.CONST || obj.name == obj.CIRC || obj.name == obj.TOEPLITZ || obj.name == obj.HANKEL;
is_mat = obj.name == obj.SUPP || obj.name == obj.CONST || obj.name == obj.CIRC || obj.name == obj.TOEPLITZ || obj.name == obj.HANKEL || obj.name == obj.BLKDIAG;
end
function str = conv2str (obj)
......@@ -84,8 +84,8 @@ classdef ConstraintName
str = 'toeplitz';
case obj.HANKEL
str = 'hankel';
%case obj.BLKDIAG;
% str = 'blkdiag'
case obj.BLKDIAG;
str = 'blockdiag';
otherwise
error('Unknown name')
end
......@@ -123,6 +123,8 @@ classdef ConstraintName
id = ConstraintName.TOEPLITZ
case 'hankel'
id = ConstraintName.HANKEL
case 'blockdiag'
id = ConstraintName.BLKDIAG
otherwise
error(err_msg)
end
......
......@@ -61,6 +61,11 @@ classdef (Abstract) ParamsFact
if(~ iscell(constraints))
error(['matfaust.factparams.ParamsFact constraints argument must be a cell array.'])
end
for i = 1:length(constraints)
if(isa(constraints{i}, 'matfaust.proj.proj_gen'))
constraints{i} = constraints{i}.constraint
end
end
for i = 1:length(constraints) %ParamsFact.TODO: check constraints length in sub-class
if(~ isa(constraints{i}, 'matfaust.factparams.ConstraintGeneric'))
error(['matfaust.factparams.ParamsFact constraints argument must contain matfaust.factparams.ConstraintGeneric objects.'])
......
......@@ -20,14 +20,14 @@ classdef ParamsHierarchical < matfaust.factparams.ParamsFact
if(iscell(fact_constraints))
for i=1:length(fact_constraints)
if(isa(fact_constraints{i}, 'matfaust.proj.proj_gen'))
fact_constraints{i} = fact_constraints{i}.constraint
fact_constraints{i} = fact_constraints{i}.constraint;
end
end
end
if(iscell(res_constraints))
for i=1:length(res_constraints)
if(isa(res_constraints{i}, 'matfaust.proj.proj_gen'))
res_constraints{i} = res_constraints{i}.constraint
res_constraints{i} = res_constraints{i}.constraint;
end
end
end
......
%==================================================
%> Functor for the BLOCKDIAG projector.
%> @brief Functor for the BLOCKDIAG projector.
%>
%> TODO
%==================================================
classdef blockdiag %< matfaust.proj.proj_gen
classdef blockdiag < matfaust.proj.proj_gen
properties
m_vec
n_vec
......@@ -13,14 +15,19 @@
M = zeros(shape(1), shape(2));
m_vec = zeros(1, length(mn_cell));
n_vec = zeros(1, length(mn_cell));
mn_mat = zeros(length(mn_cell), 2);
for i=1:length(mn_cell)
m_vec(i) = mn_cell{i}{1};
n_vec(i) = mn_cell{i}{2};
mn_mat(i,1) = mn_cell{i}{1};
mn_mat(i,2) = mn_cell{i}{2};
end
proj.m_vec = m_vec;
proj.n_vec = n_vec;
proj.normalized = false;
proj.pos = false;
proj.constraint = matfaust.factparams.ConstraintMat('blockdiag', mn_mat, varargin{:});
argc = length(varargin);
if(argc > 0)
for i=1:argc
switch(varargin{i})
......
......@@ -319,7 +319,6 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
case CONSTRAINT_NAME_HANKEL:
Faust::prox_hankel(mat, normalized, pos);
break;
break;
default:
mexErrMsgTxt("Unknown constraint name/type.");
}
......
......@@ -496,7 +496,7 @@ void setVectorFaustMat(std::vector<Faust::MatDense<FPP,Cpu> > &vecMat,mxArray *C
template<typename FPP, typename FPP2>
void getConstraint(std::vector<const Faust::ConstraintGeneric*> & consS,mxArray* mxCons)
{
mwSize bufCharLen,nbRowCons,nbColCons,nb_params;
mwSize bufCharLen,nbRowCons,nbColCons,nb_params, param_sz;
int status;
char * consName;
double paramCons;
......@@ -504,8 +504,8 @@ void getConstraint(std::vector<const Faust::ConstraintGeneric*> & consS,mxArray*
if (!mxIsCell(mxCons))
mexErrMsgTxt("tools_mex.h : getConstraint : constraint must be a cell-array. ");
nb_params = mxGetNumberOfElements(mxCons);
if (nb_params != 4)
mexErrMsgTxt("tools_mex.h : getConstraint : size of constraint must be equal to 4. ");
if (nb_params < 4 || nb_params > 5)
mexErrMsgTxt("mx2Faust.hpp: getConstraint : size of constraint must be 4 or 5. ");
// mexPrintf("getConstraint() nb_params=%d\n", nb_params);
mxConsParams=mxGetCell(mxCons,0);
......@@ -528,7 +528,8 @@ void getConstraint(std::vector<const Faust::ConstraintGeneric*> & consS,mxArray*
int const_type = get_type_constraint(consName);
faust_constraint_name consNameType=get_equivalent_constraint(consName);
if(const_type != 2 && nb_params != 4)
mexErrMsgTxt("mx2Faust.hpp: getConstraint for this constraint type (non-matrix) must be 4.");
switch(const_type)
{
case 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment