Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 97e04449 authored by hhakim's avatar hhakim
Browse files

Add an enum GradientCalcOptMode and a boolean in Faust::Params* structures to...

Add an enum GradientCalcOptMode and a boolean in Faust::Params* structures to decide to enable or not the multiplication optimization used in Palm4MSA::compute_grad_over_c and make this option available in pyfaust.factparams structures.

matfaust same modification is yet to do.
parent e0010b8e
Branches
Tags
No related merge requests found
Showing
with 99 additions and 57 deletions
......@@ -178,6 +178,7 @@ namespace Faust
const bool verbose;
const bool isUpdateWayR2L;
const bool isConstantStepSize;
const GradientCalcOptMode gradCalcOptMode;
bool isCComputed;
bool isGradComputed;
bool isProjectionComputed;
......
......@@ -99,6 +99,7 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::MatDense<FPP,DEVICE>& M,
isGlobal(isGlobal_),
isInit(false),
c(FPP2(1)/params_.step_size),
gradCalcOptMode(params_.gradCalcOptMode),
blas_handle(blasHandle),
is_complex(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>)
),
......@@ -120,26 +121,27 @@ Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::MatDense<FPP,DEVICE>& M,
template<typename FPP,Device DEVICE,typename FPP2>
Faust::Palm4MSA<FPP,DEVICE,FPP2>::Palm4MSA(const Faust::ParamsPalm<FPP,DEVICE,FPP2>& params_palm_,const Faust::BlasHandle<DEVICE> blasHandle,const bool isGlobal_/*=false*/) :
stop_crit(params_palm_.stop_crit),
data(params_palm_.data),
m_lambda(params_palm_.init_lambda),
m_nbFact(params_palm_.nbFact),
S(params_palm_.init_fact),
RorL(vector<Faust::MatDense<FPP,DEVICE> >(2)),
LorR(Faust::MatDense<FPP,DEVICE>(params_palm_.init_fact[0].getNbRow())),
const_vec(params_palm_.cons),
m_indFact(0),
m_indIte(-1),
verbose(params_palm_.isVerbose),
isUpdateWayR2L(params_palm_.isUpdateWayR2L),
isConstantStepSize(params_palm_.isConstantStepSize),
isGradComputed(false),
isProjectionComputed(false),
isLastFact(false),
isConstraintSet(false),
isGlobal(isGlobal_),
c(FPP2(1)/params_palm_.step_size),
blas_handle(blasHandle),
stop_crit(params_palm_.stop_crit),
data(params_palm_.data),
m_lambda(params_palm_.init_lambda),
m_nbFact(params_palm_.nbFact),
S(params_palm_.init_fact),
RorL(vector<Faust::MatDense<FPP,DEVICE> >(2)),
LorR(Faust::MatDense<FPP,DEVICE>(params_palm_.init_fact[0].getNbRow())),
const_vec(params_palm_.cons),
m_indFact(0),
m_indIte(-1),
verbose(params_palm_.isVerbose),
isUpdateWayR2L(params_palm_.isUpdateWayR2L),
isConstantStepSize(params_palm_.isConstantStepSize),
isGradComputed(false),
isProjectionComputed(false),
isLastFact(false),
isConstraintSet(false),
isGlobal(isGlobal_),
c(FPP2(1)/params_palm_.step_size),
gradCalcOptMode(params_palm_.gradCalcOptMode),
blas_handle(blasHandle),
is_complex(typeid(data.getData()[0]) == typeid(complex<float>) || typeid(data.getData()[0]) == typeid(complex<double>)
),
TorH(is_complex?'H':'T')
......@@ -246,7 +248,7 @@ t_local_compute_grad_over_c.start();
error = data;
Faust::MatDense<FPP,DEVICE> tmp1,tmp3;
if (idx==0 || idx==1) // computing L*S first, then (L*S)*R
if (idx==0 || idx==1 || gradCalcOptMode == DISABLED) // computing L*S first, then (L*S)*R
{
if (!isUpdateWayR2L)
{
......@@ -323,7 +325,7 @@ sprintf(nomFichier,"error_1_%d_device.tmp",cmpt);*/
}
}
if (idx==0 || idx==2) // computing L'*error first, then (L'*error)*R'
if (idx==0 || idx==2 || gradCalcOptMode == DISABLED) // computing L'*error first, then (L'*error)*R'
{
if (!isUpdateWayR2L)
{
......
......@@ -61,6 +61,14 @@
namespace Faust
{
/** Modes avaiable for the possible optimization in compute_grad_over_c() */
enum GradientCalcOptMode
{
DISABLED, // no optimization at all
INTERNAL_OPT, //the optimization defined internally must be used
EXTERNAL_OPT //the optimization defined externally in faust_linear_algebra must be used
};
template<typename FPP,Device DEVICE> class MatDense;
......@@ -86,7 +94,8 @@ namespace Faust
const bool isFactSideLeft_ = defaultFactSideLeft ,
const FPP2 init_lambda_ = defaultLambda ,
const bool constant_step_size_ = defaultConstantStepSize,
const FPP2 step_size_ = defaultStepSize);
const FPP2 step_size_ = defaultStepSize,
const GradientCalcOptMode gradCalcOptMode = defaultGradCalcOptMode);
/*!
......@@ -123,7 +132,7 @@ namespace Faust
*/
Params(
const faust_unsigned_int nbRow_,
const faust_unsigned_int nbCol_,
const faust_unsigned_int nbCol_,
const unsigned int nbFact_,
const std::vector<std::vector<const Faust::ConstraintGeneric*>> & cons_,
const std::vector<Faust::MatDense<FPP,DEVICE> >& init_fact_,
......@@ -134,14 +143,14 @@ namespace Faust
const bool isFactSideLeft_ = defaultFactSideLeft ,
const FPP2 init_lambda_ = defaultLambda ,
const bool constant_step_size_ = defaultConstantStepSize,
const FPP2 step_size_ = defaultStepSize);
const FPP2 step_size_ = defaultStepSize,
const GradientCalcOptMode gradCalcOptMode = defaultGradCalcOptMode);
Params();
void init_from_file(const char* filename);
void check_constraint_validity();
void check_bool_validity();
virtual void Display() const;
~Params(){}
......@@ -167,6 +176,7 @@ namespace Faust
FPP2 init_lambda;
bool isConstantStepSize;
FPP2 step_size;
GradientCalcOptMode gradCalcOptMode;
//default value
static const int defaultNiter1;
......@@ -179,6 +189,7 @@ namespace Faust
static const FPP2 defaultStepSize;
static const FPP defaultDecreaseSpeed;
static const FPP defaultResiduumPercent;
static const GradientCalcOptMode defaultGradCalcOptMode;
//const int nb_rows; // number of rows of the first factor
//const int nb_cols; // number of columns of the last factor
......
......@@ -71,6 +71,8 @@ template<typename FPP,Device DEVICE,typename FPP2> const FPP2 Faust::Params<FPP,
template<typename FPP,Device DEVICE,typename FPP2> const FPP Faust::Params<FPP,DEVICE,FPP2>::defaultDecreaseSpeed = 1.25;
template<typename FPP,Device DEVICE,typename FPP2> const FPP Faust::Params<FPP,DEVICE,FPP2>::defaultResiduumPercent = 1.4;
template<typename FPP,Device DEVICE,typename FPP2> const Faust::GradientCalcOptMode Faust::Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode = INTERNAL_OPT;
template<typename FPP,Device DEVICE,typename FPP2>
void Faust::Params<FPP,DEVICE,FPP2>::check_constraint_validity()
......@@ -124,7 +126,8 @@ Faust::Params<FPP,DEVICE,FPP2>::Params(
const bool isFactSideLeft_ , /* = false */
const FPP2 init_lambda_ /* = 1.0 */,
const bool constant_step_size_,
const FPP2 step_size_):
const FPP2 step_size_,
const GradientCalcOptMode gradCalcOptMode/* default INTERNAL_OPT */):
m_nbRow(nbRow_),
m_nbCol(nbCol_),
m_nbFact(nbFact_),
......@@ -136,7 +139,8 @@ Faust::Params<FPP,DEVICE,FPP2>::Params(
isFactSideLeft(isFactSideLeft_),
init_lambda(init_lambda_),
isConstantStepSize(constant_step_size_),
step_size(step_size_)
step_size(step_size_),
gradCalcOptMode(gradCalcOptMode)
{
if (nbFact_ <= 2)
{
......@@ -234,9 +238,10 @@ Faust::Params<FPP,DEVICE,FPP2>::Params(
const bool isFactSideLeft_ /* = false */,
const FPP2 init_lambda_ /* = 1.0 */,
const bool constant_step_size_ ,
const FPP2 step_size_ ) :
const FPP2 step_size_ ,
const GradientCalcOptMode gradCalcOptMode /* default INTERNAL_OPT */) :
m_nbRow(nbRow_),
m_nbCol(nbCol_),
m_nbCol(nbCol_),
m_nbFact(nbFact_),
cons(cons_),
init_fact(init_fact_),
......@@ -247,7 +252,8 @@ Faust::Params<FPP,DEVICE,FPP2>::Params(
isFactSideLeft(isFactSideLeft_),
init_lambda(init_lambda_),
isConstantStepSize(constant_step_size_),
step_size(step_size_)
step_size(step_size_),
gradCalcOptMode(gradCalcOptMode)
{
check_constraint_validity();
......@@ -271,7 +277,8 @@ Faust::Params<FPP,DEVICE,FPP2>::Params() : m_nbRow(0),
init_fact(std::vector<Faust::MatDense<FPP,DEVICE> >()),
init_lambda(defaultLambda),
isConstantStepSize(defaultConstantStepSize),
step_size(defaultStepSize)
step_size(defaultStepSize),
gradCalcOptMode(defaultGradCalcOptMode)
{}
......@@ -300,7 +307,7 @@ void Faust::Params<FPP,DEVICE,FPP2>::Display() const
std::cout<<"Matrix : nbRow "<<m_nbRow<<" NbCol : "<< m_nbCol<<std::endl;
std::cout<<"stop_crit_2facts : "<<stop_crit_2facts.get_crit()<<std::endl;
std::cout<<"stop_crit_global : "<<stop_crit_global.get_crit()<<std::endl;
std::cout << "gradCalcOptMode: "<< gradCalcOptMode << std::endl;
/*cout<<"INIT_FACTS :"<<endl;
for (int L=0;L<init_fact.size();L++)init_fact[L].Display();*/
......
......@@ -30,7 +30,7 @@ namespace Faust
const bool isFactSideLeft = Params<FPP,DEVICE,FPP2>::defaultFactSideLeft,
const FPP2 init_lambda = Params<FPP,DEVICE,FPP2>::defaultLambda,
const bool constant_step_size = Params<FPP,DEVICE,FPP2>::defaultConstantStepSize,
const FPP2 step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize): Params<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, stop_crit_2facts, stop_crit_global, isVerbose, isUpdateWayR2L, isFactSideLeft, init_lambda, constant_step_size, step_size), init_D(init_D)
const FPP2 step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize, const Faust::GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode): Params<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, stop_crit_2facts, stop_crit_global, isVerbose, isUpdateWayR2L, isFactSideLeft, init_lambda, constant_step_size, step_size, gradCalcOptMode), init_D(init_D)
{
}
......@@ -49,7 +49,7 @@ namespace Faust
const bool isFactSideLeft = Params<FPP,DEVICE,FPP2>::defaultFactSideLeft,
const FPP2 init_lambda = Params<FPP,DEVICE,FPP2>::defaultLambda,
const bool constant_step_size = Params<FPP,DEVICE,FPP2>::defaultConstantStepSize,
const FPP2 step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize): Params<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, stop_crit_2facts, stop_crit_global, isVerbose, isUpdateWayR2L, isFactSideLeft, init_lambda, constant_step_size, step_size), init_D(nbRow, nbCol)
const FPP2 step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize, const Faust::GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode): Params<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, stop_crit_2facts, stop_crit_global, isVerbose, isUpdateWayR2L, isFactSideLeft, init_lambda, constant_step_size, step_size, gradCalcOptMode), init_D(nbRow, nbCol)
{
init_D.setZeros();
// set init_D from diagonal vector init_D_diag
......@@ -71,7 +71,7 @@ namespace Faust
const bool isFactSideLeft = Params<FPP,DEVICE,FPP2>::defaultFactSideLeft,
const FPP2 init_lambda = Params<FPP,DEVICE,FPP2>::defaultLambda,
const bool constant_step_size = Params<FPP,DEVICE,FPP2>::defaultConstantStepSize,
const FPP2 step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize): ParamsFGFT<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, Faust::Vect<FPP,DEVICE>(nbRow, init_D_diag), stop_crit_2facts, stop_crit_global, isVerbose, isUpdateWayR2L, isFactSideLeft, init_lambda, constant_step_size, step_size)
const FPP2 step_size = Params<FPP,DEVICE,FPP2>::defaultStepSize, const Faust::GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode): ParamsFGFT<FPP, DEVICE, FPP2>(nbRow, nbCol, nbFact, cons, init_fact, Faust::Vect<FPP,DEVICE>(nbRow, init_D_diag), stop_crit_2facts, stop_crit_global, isVerbose, isUpdateWayR2L, isFactSideLeft, init_lambda, constant_step_size, step_size, gradCalcOptMode)
{
}
......
......@@ -50,6 +50,7 @@
#endif
#include "faust_StoppingCriterion.h"
#include "faust_ConstraintGeneric.h"
#include "faust_Params.h"
/*! \class Faust::ParamsPalm
......@@ -81,7 +82,8 @@ namespace Faust
const bool isUpdateWayR2L_ = defaultUpdateWayR2L ,
const FPP2 init_lambda_ = defaultLambda,
const bool constant_step_size_ = defaultConstantStepSize,
const FPP2 step_size_ = defaultStepSize);
const FPP2 step_size_ = defaultStepSize,
const GradientCalcOptMode gradCalcOptMode = Faust::Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode);
void check_constraint_validity();
ParamsPalm();
......@@ -102,6 +104,7 @@ namespace Faust
bool isConstantStepSize;
FPP2 step_size;
FPP2 init_lambda;
GradientCalcOptMode gradCalcOptMode;
void Display() const;
void init_factors();
......
......@@ -92,7 +92,8 @@ Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm(
const bool isUpdateWayR2L_ /* = false */,
const FPP2 init_lambda_ /* = 1.0 */,
const bool constant_step_size_,
const FPP2 step_size_) :
const FPP2 step_size_,
const GradientCalcOptMode gradCalcOptMode /* default INTERNAL_OPT*/) :
data(data_),
nbFact(nbFact_),
cons(cons_),
......@@ -102,13 +103,14 @@ Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm(
isUpdateWayR2L(isUpdateWayR2L_),
init_lambda(init_lambda_),
isConstantStepSize(constant_step_size_),
step_size(step_size_)
step_size(step_size_),
gradCalcOptMode(gradCalcOptMode)
{
check_constraint_validity();
}
template<typename FPP,Device DEVICE,typename FPP2>
Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm() : data(0,0),nbFact(0),cons(std::vector<const Faust::ConstraintGeneric*>()),init_lambda(defaultLambda),isConstantStepSize(defaultConstantStepSize),step_size(defaultStepSize){}
Faust::ParamsPalm<FPP,DEVICE,FPP2>::ParamsPalm() : data(0,0),nbFact(0),cons(std::vector<const Faust::ConstraintGeneric*>()),init_lambda(defaultLambda),isConstantStepSize(defaultConstantStepSize),step_size(defaultStepSize), gradCalcOptMode(Faust::Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode) {}
template<typename FPP,Device DEVICE,typename FPP2>
void Faust::ParamsPalm<FPP,DEVICE,FPP2>::init_factors()
......@@ -160,6 +162,7 @@ void Faust::ParamsPalm<FPP,DEVICE,FPP2>::Display() const
std::cout<<"step_size : "<<step_size<<std::endl;
std::cout<<"data : nbRow "<<data.getNbRow()<<" NbCol : "<< data.getNbCol()<<std::endl;
std::cout<<"stop_crit : "<<stop_crit.get_crit()<<std::endl;
std::cout << "gradCalcOptMode: "<< gradCalcOptMode << std::endl;
/*cout<<"INIT_FACTS :"<<endl;
for (int L=0;L<init_fact.size();L++)init_fact[L].Display();*/
......
......@@ -25,7 +25,7 @@ namespace Faust
const bool isVerbose_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultVerbosity ,
const bool isUpdateWayR2L_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultUpdateWayR2L ,
const FPP2 init_lambda_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultLambda,
const FPP2 step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, true /*constant_step_size is always true for Palm4MSAFGFT */, step_size_), init_D(init_D) {}
const FPP2 step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize, const Faust::GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, true /*constant_step_size is always true for Palm4MSAFGFT */, step_size_, gradCalcOptMode), init_D(init_D) {}
ParamsPalmFGFT() : ParamsPalm<FPP,DEVICE,FPP2>(), init_D(0,0) {}
......@@ -38,7 +38,7 @@ namespace Faust
const bool isVerbose_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultVerbosity ,
const bool isUpdateWayR2L_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultUpdateWayR2L ,
const FPP2 init_lambda_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultLambda,
const FPP2 step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize);
const FPP2 step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize,const Faust::GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode);
ParamsPalmFGFT(const Faust::MatDense<FPP,DEVICE>& data_,
const int nbFact_,
......@@ -49,7 +49,7 @@ namespace Faust
const bool isVerbose_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultVerbosity ,
const bool isUpdateWayR2L_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultUpdateWayR2L ,
const FPP2 init_lambda_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultLambda,
const FPP2 step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize) : ParamsPalmFGFT<FPP,DEVICE,FPP2>(data_, nbFact_, cons_, init_fact_, Faust::Vect<FPP,DEVICE>(data_.getNbRow(), init_D_diag), stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, step_size_) {}
const FPP2 step_size_ = ParamsPalm<FPP,DEVICE,FPP2>::defaultStepSize, const Faust::GradientCalcOptMode gradCalcOptMode = Params<FPP,DEVICE,FPP2>::defaultGradCalcOptMode) : ParamsPalmFGFT<FPP,DEVICE,FPP2>(data_, nbFact_, cons_, init_fact_, Faust::Vect<FPP,DEVICE>(data_.getNbRow(), init_D_diag), stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, step_size_, gradCalcOptMode) {}
MatDense<FPP,DEVICE> init_D;
......
......@@ -8,7 +8,7 @@ template<typename FPP,Device DEVICE,typename FPP2>
const bool isVerbose_ ,
const bool isUpdateWayR2L_ ,
const FPP2 init_lambda_ ,
const FPP2 step_size_) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, true /*constant_step_size is always true for Palm4MSAFGFT */, step_size_), init_D(data_.getNbRow(), data_.getNbCol())
const FPP2 step_size_, const Faust::GradientCalcOptMode gradCalcOptMode) : ParamsPalm<FPP, DEVICE, FPP2>(data_, nbFact_, cons_, init_fact_, stop_crit_, isVerbose_, isUpdateWayR2L_, init_lambda_, true /*constant_step_size is always true for Palm4MSAFGFT */, step_size_), init_D(data_.getNbRow(), data_.getNbCol())
{
init_D.setZeros();
// set init_D from diagonal vector init_D_diag
......
......@@ -577,9 +577,12 @@ class ParamsFact(ABC):
<b/> See also ParamsHierarchical, ParamsPalm4MSA
"""
DISABLED_OPT = 0
INTERNAL_OPT = 1
EXTERNAL_OPT = 2
def __init__(self, num_facts, is_update_way_R2L, init_lambda,
constraints, step_size, constant_step_size,
is_verbose):
is_verbose, grad_calc_opt_mode=INTERNAL_OPT):
self.num_facts = num_facts
self.is_update_way_R2L = is_update_way_R2L
self.init_lambda = init_lambda
......@@ -590,6 +593,7 @@ class ParamsFact(ABC):
self.constraints = constraints
self.is_verbose = is_verbose
self.constant_step_size = constant_step_size
self.grad_calc_opt_mode = grad_calc_opt_mode
@abstractmethod
def is_mat_consistent(self, M):
......@@ -614,7 +618,8 @@ class ParamsHierarchical(ParamsFact):
stop_crit2, is_update_way_R2L=False, init_lambda=1.0,
step_size=10.0**-16, constant_step_size=False,
is_fact_side_left=False,
is_verbose=False):
is_verbose=False,
grad_calc_opt_mode=ParamsFact.INTERNAL_OPT):
if(not isinstance(fact_constraints, list) and not
isinstance(fact_constraints, ConstraintList)):
raise TypeError('fact_constraints must be a list or a'
......@@ -632,11 +637,12 @@ class ParamsHierarchical(ParamsFact):
constraints = fact_constraints + res_constraints
stop_crits = [ stop_crit1, stop_crit2 ]
super(ParamsHierarchical, self).__init__(num_facts,
is_update_way_R2L,
init_lambda,
constraints, step_size,
constant_step_size,
is_verbose)
is_update_way_R2L,
init_lambda,
constraints, step_size,
constant_step_size,
is_verbose,
grad_calc_opt_mode)
self.stop_crits = stop_crits
self.is_fact_side_left = is_fact_side_left
if((not isinstance(stop_crits, list) and not isinstance(stop_crits,
......@@ -778,7 +784,7 @@ class ParamsPalm4MSA(ParamsFact):
is_update_way_R2L=False, init_lambda=1.0,
step_size=10.0**-16,
constant_step_size=False,
is_verbose=False):
is_verbose=False, grad_calc_opt_mode=ParamsFact.INTERNAL_OPT):
if(not isinstance(constraints, list) and not
isinstance(constraints, ConstraintList)):
raise TypeError('constraints argument must be a list or a'
......@@ -788,7 +794,7 @@ class ParamsPalm4MSA(ParamsFact):
init_lambda,
constraints, step_size,
constant_step_size,
is_verbose)
is_verbose, grad_calc_opt_mode)
if(init_facts != None and (not isinstance(init_facts, list) and not isinstance(init_facts,
tuple) or
len(init_facts) != num_facts)):
......
......@@ -149,6 +149,7 @@ cdef extern from "FaustFact.h":
unsigned int num_constraints
bool is_verbose
bool constant_step_size
unsigned int grad_calc_opt_mode
cdef cppclass PyxParamsFactPalm4MSA[FPP,FPP2](PyxParamsFact[FPP,FPP2]):
FPP** init_facts # num_facts elts
......
......@@ -68,6 +68,7 @@ class PyxParamsFact
unsigned int num_constraints;
bool is_verbose;
bool constant_step_size;
unsigned int grad_calc_opt_mode;
};
template<typename FPP, typename FPP2 = double>
......
......@@ -137,6 +137,7 @@ void prepare_fact(const FPP* mat, const unsigned int num_rows, const unsigned in
cout << "p->step_size: " << p->step_size << endl;
cout << "p->is_verbose: " << p->is_verbose << endl;
cout << "p->constant_step_size: " << p->constant_step_size << endl;
cout << "p->grad_calc_opt_mode: " << p->grad_calc_opt_mode << endl;
}
PyxConstraintInt* cons_int;
PyxConstraintScalar<FPP2>* cons_real;
......@@ -228,12 +229,12 @@ FaustCoreCpp<FPP>* fact_palm4MSA_gen(FPP* mat, unsigned int num_rows, unsigned i
if(p_fft = dynamic_cast<PyxParamsFactPalm4MSAFFT<FPP,FPP2>*>(p))
{
params = new ParamsPalmFGFT<FPP,Cpu,FPP2>(inMat, p->num_facts, cons, initFacts, p_fft->init_D, crit, p->is_verbose, p->is_update_way_R2L, p->init_lambda, p->step_size);
params = new ParamsPalmFGFT<FPP,Cpu,FPP2>(inMat, p->num_facts, cons, initFacts, p_fft->init_D, crit, p->is_verbose, p->is_update_way_R2L, p->init_lambda, p->step_size, static_cast<Faust::GradientCalcOptMode>(p->grad_calc_opt_mode));
palm = new Palm4MSAFGFT<FPP,Cpu,FPP2>(*static_cast<ParamsPalmFGFT<FPP,Cpu,FPP2>*>(params),blasHandle,true);
}
else {
params = new ParamsPalm<FPP,Cpu,FPP2>(inMat, p->num_facts, cons, initFacts, crit, p->is_verbose, p->is_update_way_R2L, p->init_lambda, p->constant_step_size, p->step_size);
params = new ParamsPalm<FPP,Cpu,FPP2>(inMat, p->num_facts, cons, initFacts, crit, p->is_verbose, p->is_update_way_R2L, p->init_lambda, p->constant_step_size, p->step_size, static_cast<Faust::GradientCalcOptMode>(p->grad_calc_opt_mode));
palm = new Palm4MSA<FPP,Cpu,FPP2>(*params,blasHandle,true);
}
......@@ -338,6 +339,8 @@ FaustCoreCpp<FPP>* fact_hierarchical_gen(FPP* mat, FPP* mat2, unsigned int num_r
cout << "stop_crits[1].error_treshold: " << p->stop_crits[1].error_treshold << endl;
cout << "stop_crits[1].num_its: " << p->stop_crits[1].num_its << endl;
cout << "stop_crits[1].max_num_its: " << p->stop_crits[1].max_num_its << endl;
cout << "p->grad_calc_opt_mode: " << p->grad_calc_opt_mode << endl;
}
prepare_fact(mat, num_rows, num_cols, p, cons);
......@@ -358,12 +361,12 @@ FaustCoreCpp<FPP>* fact_hierarchical_gen(FPP* mat, FPP* mat2, unsigned int num_r
if(p_fft = dynamic_cast<PyxParamsHierarchicalFactFFT<FPP,FPP2>*>(p))
{
inMat2 = new Faust::MatDense<FPP,Cpu>(mat2, num_rows, num_cols);
params = new Faust::ParamsFGFT<FPP,Cpu,FPP2>(p->num_rows, p->num_cols, p->num_facts, cons_, initFacts_deft, p_fft->init_D, crit0, crit1, p->is_verbose, p->is_update_way_R2L, p->is_fact_side_left, p->init_lambda, p->constant_step_size, p->step_size);
params = new Faust::ParamsFGFT<FPP,Cpu,FPP2>(p->num_rows, p->num_cols, p->num_facts, cons_, initFacts_deft, p_fft->init_D, crit0, crit1, p->is_verbose, p->is_update_way_R2L, p->is_fact_side_left, p->init_lambda, p->constant_step_size, p->step_size, static_cast<Faust::GradientCalcOptMode>(p->grad_calc_opt_mode));
hierFact = new HierarchicalFactFGFT<FPP,Cpu,FPP2>(inMat, *inMat2, *(static_cast<ParamsFGFT<FPP,Cpu,FPP2>*>(params)), blasHandle, spblasHandle);
}
else
{
params = new Params<FPP,Cpu,FPP2>(p->num_rows, p->num_cols, p->num_facts, cons_, initFacts_deft, crit0, crit1, p->is_verbose, p->is_update_way_R2L, p->is_fact_side_left, p->init_lambda, p->constant_step_size, p->step_size);
params = new Params<FPP,Cpu,FPP2>(p->num_rows, p->num_cols, p->num_facts, cons_, initFacts_deft, crit0, crit1, p->is_verbose, p->is_update_way_R2L, p->is_fact_side_left, p->init_lambda, p->constant_step_size, p->step_size, static_cast<Faust::GradientCalcOptMode>(p->grad_calc_opt_mode));
hierFact = new HierarchicalFact<FPP,Cpu,FPP2>(inMat, *params, blasHandle, spblasHandle);
}
......
......@@ -1013,6 +1013,7 @@ cdef class FaustFact:
cpp_params.is_update_way_R2L = p.is_update_way_R2L
cpp_params.init_lambda = p.init_lambda
cpp_params.step_size = p.step_size
cpp_params.grad_calc_opt_mode = p.grad_calc_opt_mode
cpp_params.stop_crit = cpp_stop_crit
cpp_params.init_facts = <double**> \
PyMem_Malloc(sizeof(double*)*p.num_facts)
......@@ -1036,6 +1037,7 @@ cdef class FaustFact:
cpp_params_cplx.is_update_way_R2L = p.is_update_way_R2L
cpp_params_cplx.init_lambda = p.init_lambda
cpp_params_cplx.step_size = p.step_size
cpp_params_cplx.grad_calc_opt_mode = p.grad_calc_opt_mode
cpp_params_cplx.stop_crit = cpp_stop_crit
cpp_params_cplx.init_facts = <complex**> \
PyMem_Malloc(sizeof(complex*)*p.num_facts)
......@@ -1230,6 +1232,7 @@ cdef class FaustFact:
cpp_params.is_update_way_R2L = p.is_update_way_R2L
cpp_params.init_lambda = p.init_lambda
cpp_params.step_size = p.step_size
cpp_params.grad_calc_opt_mode = p.grad_calc_opt_mode
cpp_params.stop_crits = cpp_stop_crits
cpp_params.is_verbose = p.is_verbose
cpp_params.is_fact_side_left = p.is_fact_side_left
......@@ -1362,6 +1365,7 @@ cdef class FaustFact:
cpp_params_cplx.is_update_way_R2L = p.is_update_way_R2L
cpp_params_cplx.init_lambda = p.init_lambda
cpp_params_cplx.step_size = p.step_size
cpp_params_cplx.grad_calc_opt_mode = p.grad_calc_opt_mode
cpp_params_cplx.stop_crits = cpp_stop_crits
cpp_params_cplx.is_verbose = p.is_verbose
cpp_params_cplx.is_fact_side_left = p.is_fact_side_left
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment