Commit ee22f044 authored by LOPEZ GANDIA Axel's avatar LOPEZ GANDIA Axel

Merge branch '54-parameters-factory' into 'master'

Resolve "Input parameters for cost functions in XML"

Closes #54

See merge request !80
parents f41ef50b f01fd748
......@@ -38,6 +38,7 @@ class FOEAvoidance : public CostFunction
private:
public:
const static std::string GetName() { return "FOEAvoidance"; }
const static std::string Name;
FOEAvoidance();
......@@ -47,6 +48,7 @@ class FOEAvoidance : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~FOEAvoidance();
};
......
......@@ -38,6 +38,7 @@ class GenericCost : public CostFunction
public:
const static std::string Name;
const static std::string GetName() { return "Generic"; }
GenericCost();
/**
......@@ -45,6 +46,7 @@ class GenericCost : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~GenericCost();
};
......
......@@ -38,6 +38,7 @@ class PowerLaw : public CostFunction
private:
public:
const static std::string GetName() { return "PowerLaw"; }
float tau0 = 3;
......@@ -49,6 +50,7 @@ class PowerLaw : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~PowerLaw();
};
......
......@@ -38,6 +38,7 @@ class RandomFunction : public CostFunction
std::default_random_engine RNGengine;
std::uniform_real_distribution<float> RNGdistribution;
public:
const static std::string GetName() { return "RandomFunction"; }
int RNGSeed = 0;
const static std::string Name;
......@@ -49,6 +50,7 @@ class RandomFunction : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~RandomFunction();
};
......
......@@ -36,6 +36,7 @@ class SocialForcesAvoidance : public CostFunction
private:
float sigma = 0.3f;
public:
const static std::string GetName() { return "SocialForcesAvoidance"; }
const static std::string Name;
SocialForcesAvoidance();
......@@ -45,6 +46,7 @@ class SocialForcesAvoidance : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~SocialForcesAvoidance();
};
......
......@@ -36,6 +36,7 @@ class SocialForcesGoalReaching : public CostFunction
private:
public:
const static std::string GetName() { return "SocialForcesGoalReaching"; }
const static std::string Name;
SocialForcesGoalReaching();
......@@ -45,6 +46,7 @@ class SocialForcesGoalReaching : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~SocialForcesGoalReaching();
};
......
......@@ -56,7 +56,8 @@ class TtcaDca : public CostFunction
public:
const static std::string GetName() { return "TtcaDca"; }
/*
* Name of the cost function, to be used in configuration file
*/
......@@ -75,6 +76,7 @@ class TtcaDca : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
};
#endif //LIB_TTCADCA_H
......@@ -39,6 +39,7 @@ class DirGoalReaching : public CostFunction
public:
const static std::string Name;
const static std::string GetName() { return "DirectionalGoalReaching"; }
DirGoalReaching();
......@@ -47,6 +48,7 @@ class DirGoalReaching : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~DirGoalReaching();
};
......
......@@ -38,6 +38,7 @@ class GoalReaching : public CostFunction
private:
public:
const static std::string GetName() { return "GoalReaching"; }
const static std::string Name;
GoalReaching();
......@@ -47,6 +48,7 @@ class GoalReaching : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~GoalReaching();
};
......
......@@ -36,6 +36,7 @@ class SocialForces : public CostFunction
private:
public:
const static std::string GetName() { return "SocialForces"; }
const static std::string Name;
......@@ -49,6 +50,7 @@ class SocialForces : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
virtual ~SocialForces();
};
......
......@@ -25,7 +25,7 @@
#ifndef LIB_COST_FUNCTION_H
#define LIB_COST_FUNCTION_H
#include "tinyxml2.h"
#include <tools/vector2D.h>
#include <string>
//#include <agent/agent.h>
......@@ -39,6 +39,30 @@ struct CostFunctionValues {
Vector2D Gradient = Vector2D(0,0);
};
/*
*
* To create a new cost function you need to create a new class and derive it (public inheritance) from CostFunction.
* Then you need to override the functions GetCostFunctionGradient(Agent* agent, WorldBase * world), parseParameters(tinyxml2::XMLElement* v)
* and add a new one const static std::string GetName() { return "MyCostFuncID"; }. Finally in the cpp file you must add REGISTER_COST_FUNCTION(MyNewFunction)
*
It should look like this:
MyNewFunction.h
class MyNewFunction : public CostFunction{
public:
const static std::string GetName() { return "MyNewFunctionId"; }
CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world) override;
void parseParameters(tinyxml2::XMLElement* v) override;
};
MyNewFunction.cpp
REGISTER_COST_FUNCTION(MyNewFunction)
*
*
*
*
*/
class CostFunction {
protected:
//unsigned int agentId_;
......@@ -46,7 +70,7 @@ class CostFunction {
//float lambdaMovementCost_;
//WorldBase* world_;
std::string name_;
float coefficient_;
float coefficient_ = 1;
public:
CostFunction();
......@@ -56,7 +80,10 @@ class CostFunction {
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world) = 0;
void setCoefficient(float v);
/**
* Parses the parameters of the cost function.
*/
virtual void parseParameters(tinyxml2::XMLElement* v);
virtual ~CostFunction();
std::string getName() const;
......
......@@ -3,13 +3,38 @@
#include <memory>
#include <core/costFunction.h>
#include <functional>
#include <map>
#include <string>
class CostFunctionFactory
{
public:
typedef std::function < std::shared_ptr<CostFunction>() > Creator;
typedef std::map<std::string, Creator> Registry;
CostFunctionFactory();
~CostFunctionFactory();
static std::shared_ptr<CostFunction> CreateCostFunction(const std::string& name);
static void RegisterCostFunction(const std::string &name, Creator creator);
private:
static Registry& GetRegistry();
};
template<typename Type>
std::shared_ptr<CostFunction> DefaultObjectCreator() {
return std::make_shared<Type>();
}
class CostFunctionRegistrator {
public:
CostFunctionRegistrator(const std::string &name, CostFunctionFactory::Creator creator);
};
#define REGISTER_COST_FUNCTION(classname) CostFunctionRegistrator g_##classname(classname::GetName(), DefaultObjectCreator<classname>);
#endif
......@@ -26,7 +26,7 @@
#ifndef LIB_POLICY_H
#define LIB_POLICY_H
#include "tinyxml2.h"
#include <core/costFunction.h>
#include <tools/vector2D.h>
#include <vector>
......@@ -43,7 +43,7 @@ class Policy {
public:
Vector2D getNewVelocity(Agent* agent, WorldBase * world);
void addCostFunction(const std::shared_ptr<CostFunction>& costFunction, float weight = 1);
void addCostFunction(const std::shared_ptr<CostFunction>& costFunction, tinyxml2::XMLElement* args);
//std::vector<double> getParameters() const;
std::vector<std::shared_ptr<CostFunction> > & getCostFunctions();
......
......@@ -23,16 +23,6 @@
** Contact: crowd_group@inria.fr
*/
//========================================================================
/*!
@file xmlparser.h
@class XMLParser
@date 22/6/2018
@brief
@author Javad Amirian, (C) 2018
*/
//========================================================================
#ifndef _XML_PARSER_H
#define _XML_PARSER_H
......
......@@ -40,7 +40,6 @@ int main( int argc, char * argv[] )
cs.setOutputDir("./output/");
cs.runMasterConfigFile("./MainConfig.xml");
//cs.runWorld(1000, 0.0333);
std::cout << "Simulation done"<< std::endl;
......
......@@ -24,6 +24,7 @@
*/
#include <CostFunctions/FOEAvoidance.h>
#include "core/costFunctionFactory.h"
#include <core/agent.h>
#include <core/worldBase.h>
#include <tools/Matrix.h>
......@@ -116,3 +117,11 @@ CostFunctionValues FOEAvoidance::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void FOEAvoidance::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(FOEAvoidance)
\ No newline at end of file
......@@ -26,6 +26,7 @@
#include <CostFunctions/GenericCost.h>
#include <core/agent.h>
#include <core/worldBase.h>
#include "core/costFunctionFactory.h"
#include <algorithm>
#include <iostream>
......@@ -50,3 +51,10 @@ CostFunctionValues GenericCost::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void GenericCost::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(GenericCost)
\ No newline at end of file
......@@ -27,6 +27,7 @@
#include <core/agent.h>
#include <core/worldBase.h>
#include <tools/Matrix.h>
#include "core/costFunctionFactory.h"
#include <algorithm>
#include <iostream>
......@@ -82,3 +83,10 @@ CostFunctionValues PowerLaw::GetCostFunctionGradient(Agent* agent, WorldBase * w
Result.TotalCost = coefficient_ * Result.TotalCost;
return Result;
}
void PowerLaw::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(PowerLaw)
\ No newline at end of file
......@@ -25,6 +25,7 @@
#include <CostFunctions/RandomFunction.h>
#include <core/agent.h>
#include "core/costFunctionFactory.h"
#include <iostream>
using namespace std;
......@@ -48,3 +49,10 @@ CostFunctionValues RandomFunction::GetCostFunctionGradient(Agent* agent, WorldBa
return Result;
}
void RandomFunction::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(RandomFunction)
......@@ -26,6 +26,7 @@
#include <CostFunctions/SocialForcesAvoidance.h>
#include <core/agent.h>
#include <core/worldBase.h>
#include "core/costFunctionFactory.h"
#include <algorithm>
#include <iostream>
......@@ -45,7 +46,7 @@ CostFunctionValues SocialForcesAvoidance::GetCostFunctionGradient(Agent* agent,
CostFunctionValues Result;
const Vector2D& AgentPos = agent->getPosition();
const vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), 1000);
const vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), 2);
Vector2D Repulsion(0, 0); float TotalCost = 0;
float dt = world->getDeltaTime();
......@@ -75,3 +76,10 @@ CostFunctionValues SocialForcesAvoidance::GetCostFunctionGradient(Agent* agent,
return Result;
}
void SocialForcesAvoidance::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(SocialForcesAvoidance)
\ No newline at end of file
......@@ -26,6 +26,7 @@
#include <CostFunctions/SocialForcesGoalReaching.h>
#include <core/agent.h>
#include <core/worldBase.h>
#include "core/costFunctionFactory.h"
#include <algorithm>
#include <iostream>
......@@ -58,3 +59,10 @@ CostFunctionValues SocialForcesGoalReaching::GetCostFunctionGradient(Agent* agen
return Result;
}
void SocialForcesGoalReaching::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(SocialForcesGoalReaching)
\ No newline at end of file
......@@ -30,6 +30,7 @@
#include <limits>
#include <iostream>
#include "tools/Matrix.h"
#include "core/costFunctionFactory.h"
using namespace std;
const std::string TtcaDca::Name = "TtcaDca";
......@@ -152,3 +153,10 @@ CostFunctionValues TtcaDca::GetCostFunctionGradient(Agent* agent, WorldBase * wo
}
return Result;
}
void TtcaDca::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(TtcaDca)
\ No newline at end of file
......@@ -25,6 +25,7 @@
#include <CostFunctions/directionalGoalReaching.h>
#include <core/agent.h>
#include "core/costFunctionFactory.h"
#include <iostream>
using namespace std;
......@@ -53,3 +54,11 @@ CostFunctionValues DirGoalReaching::GetCostFunctionGradient(Agent* agent, WorldB
return Result;
}
void DirGoalReaching::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(DirGoalReaching);
\ No newline at end of file
......@@ -25,6 +25,7 @@
#include <CostFunctions/goalReaching.h>
#include <core/agent.h>
#include "core/costFunctionFactory.h"
#include <iostream>
using namespace std;
......@@ -57,3 +58,10 @@ CostFunctionValues GoalReaching::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void GoalReaching::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(GoalReaching)
\ No newline at end of file
......@@ -26,6 +26,7 @@
#include <CostFunctions/socialforces.h>
#include <core/agent.h>
#include <core/worldBase.h>
#include "core/costFunctionFactory.h"
#include <algorithm>
#include <iostream>
......@@ -78,3 +79,26 @@ CostFunctionValues SocialForces::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void SocialForces::parseParameters(tinyxml2::XMLElement * v)
{
CostFunction::parseParameters(v);
}
REGISTER_COST_FUNCTION(SocialForces)
/*
class AngryCostFunction : public CostFunction {
public:
const static std::string GetName() { return "Angry"; }
CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world) override {
std::cerr << "Angry" << std::endl;
return CostFunctionValues();
}
void parseParameters(tinyxml2::XMLElement* v) override {
CostFunction::parseParameters(v);
}
};
REGISTER_COST_FUNCTION(AngryCostFunction)
*/
......@@ -22,7 +22,7 @@
**
** Contact: crowd_group@inria.fr
*/
#include "tinyxml2.h"
#include <core/costFunction.h>
#include <core/worldBase.h>
......@@ -46,8 +46,12 @@ CostFunction::~CostFunction() {
}
void CostFunction::setCoefficient(float v) {
coefficient_ = v;
void CostFunction::parseParameters(tinyxml2::XMLElement* v) {
float val;
if (v->QueryFloatAttribute("coeff", &val) == tinyxml2::XML_SUCCESS) {
coefficient_ = val;
}
}
/*
......
#include "core/costFunctionFactory.h"
#include "CostFunctions/FOEAvoidance.h"
#include "CostFunctions/GenericCost.h"
#include "CostFunctions/goalReaching.h"
#include "CostFunctions/socialforces.h"
#include "CostFunctions/SocialForcesAvoidance.h"
#include "CostFunctions/SocialForcesGoalReaching.h"
#include "CostFunctions/TtcaDca.h"
#include "CostFunctions/RandomFunction.h"
#include "CostFunctions/PowerLaw.h"
#include "CostFunctions/directionalGoalReaching.h"
#include <iostream>
CostFunctionFactory::CostFunctionFactory()
{
......@@ -19,39 +9,38 @@ CostFunctionFactory::~CostFunctionFactory()
{
}
CostFunctionFactory::Registry& CostFunctionFactory::GetRegistry() {
static Registry registry;
return registry;
}
void CostFunctionFactory::RegisterCostFunction(const std::string &name, Creator creator) {
Registry &registry = GetRegistry();
if (registry.count(name) > 0) {
std::cerr << "Error: cost function " << name << " has already been registered." << std::endl;
return;
}
registry[name] = creator;
}
std::shared_ptr<CostFunction>
CostFunctionFactory::CreateCostFunction(const std::string& name) {
if (name == GenericCost::Name) {
return std::make_shared<GenericCost>();
}
if (name == FOEAvoidance::Name) {
return std::make_shared<FOEAvoidance>();
}
if (name == GoalReaching::Name) {
return std::make_shared<GoalReaching>();
}
if (name == SocialForces::Name) {
return std::make_shared<SocialForces>();
}
if (name == SocialForcesAvoidance::Name) {
return std::make_shared<SocialForcesAvoidance>();
}
if (name == SocialForcesGoalReaching::Name) {
return std::make_shared<SocialForcesGoalReaching>();
}
if (name == TtcaDca::Name) {
return std::make_shared<TtcaDca>();
}
if (name == RandomFunction::Name) {
return std::make_shared<RandomFunction>();
Registry &registry = GetRegistry();
if (registry.count(name) == 0) {
std::cerr << "Error: cost function " << name << " does not exist, known cost functions: ";
for (auto &elm : registry) {
std::cerr << ", " << elm.first;
}
std::cerr << "." << std::endl;
return nullptr;
}
if (name == PowerLaw::Name) {
return std::make_shared<PowerLaw>();
}
if (name == DirGoalReaching::Name) {
return std::make_shared<DirGoalReaching>();
}
std::cerr << "The given cost function name does not exist" << std::endl;
return nullptr;
return registry[name]();
}
CostFunctionRegistrator::CostFunctionRegistrator(const std::string &name, CostFunctionFactory::Creator creator) {
CostFunctionFactory::RegisterCostFunction(name, creator);
}
\ No newline at end of file
......@@ -98,7 +98,7 @@ void CrowdSimulator::stepWorld()
poss[agent->getID()] = agent->getPosition();
writer.appendPedPositions(poss, t);
writer.flush(); //! @todo: should be removed after test
//writer.flush(); //! @todo: should be removed after test
}
/**
......
......@@ -22,7 +22,7 @@
**
** Contact: crowd_group@inria.fr
*/
#include "tinyxml2.h"
#include <core/policy.h>
#include <core/agent.h>
#include <core/worldBase.h>
......@@ -53,7 +53,7 @@ Vector2D Policy::getNewVelocity(Agent* agent, WorldBase * world) {
return NewVel.getnormalized()*clamp(NewVel.magnitude(), 0.f, agent->getMaximumSpeed());
}
void Policy::addCostFunction(const std::shared_ptr<CostFunction>& costFunction, float weight) {
costFunction->setCoefficient(weight);
void Policy::addCostFunction(const std::shared_ptr<CostFunction>& costFunction, tinyxml2::XMLElement* args) {
costFunction->parseParameters(args);
cost_functions_.push_back(costFunction);
}
......@@ -151,15 +151,12 @@ void XMLParser::load(const std::string &filename, CrowdSimulator * crowdsimulato
std::string costFunctionName;
costFunctionName = funcElement->Attribute("name");
float coeff;
funcElement->QueryFloatAttribute("coeff", &coeff);
std::shared_ptr<CostFunction> costFunction;
costFunction =
CostFunctionFactory().CreateCostFunction(costFunctionName);
pl->addCostFunction(costFunction, coeff);
CostFunctionFactory::CreateCostFunction(costFunctionName);
if (costFunction) {
pl->addCostFunction(costFunction, funcElement);
}
} while ((funcElement = funcElement->NextSiblingElement()) != NULL);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment