Commit 75cefaeb authored by LOPEZ GANDIA Axel's avatar LOPEZ GANDIA Axel

Reworked parameter parsing for cost functions. Added range parameter to cost functions.

parent 091aeca9
......@@ -48,7 +48,7 @@ class FOEAvoidance : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~FOEAvoidance();
};
......
......@@ -45,7 +45,7 @@ class GenericCost : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~GenericCost();
};
......
......@@ -49,7 +49,7 @@ class PowerLaw : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~PowerLaw();
};
......
......@@ -49,7 +49,7 @@ class RandomFunction : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~RandomFunction();
};
......
......@@ -46,7 +46,7 @@ class SocialForcesAvoidance : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~SocialForcesAvoidance();
};
......
......@@ -45,7 +45,7 @@ class SocialForcesGoalReaching : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~SocialForcesGoalReaching();
};
......
......@@ -75,7 +75,7 @@ class TtcaDca : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
};
#endif //LIB_TTCADCA_H
......@@ -48,7 +48,7 @@ class DirGoalReaching : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~DirGoalReaching();
};
......
......@@ -47,7 +47,7 @@ class GoalReaching : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~GoalReaching();
};
......
......@@ -51,7 +51,7 @@ class SocialForces : public CostFunction
* Return the values to the policy to update the agent.
*/
virtual CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world);
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
virtual ~SocialForces();
};
......
......@@ -23,7 +23,7 @@
#ifndef LIB_COST_FUNCTION_H
#define LIB_COST_FUNCTION_H
#include "tinyxml2.h"
#include "tools/xmlparser.h"
#include <tools/vector2D.h>
#include <string>
//#include <agent/agent.h>
......@@ -52,7 +52,7 @@ Example:
public:
const static std::string GetName() { return "MyNewFunctionId"; }
CostFunctionValues GetCostFunctionGradient(Agent* agent, WorldBase * world) override;
void parseParameters(tinyxml2::XMLElement* v) override;
void parseParameters(const CostFunctionParameters & params) override;
};
MyNewFunction.cpp
......@@ -60,7 +60,7 @@ Example:
CostFunctionValues MyNewFunction::GetCostFunctionGradient(Agent* agent, WorldBase * world) {
//Implementation
}
void MyNewFunction::parseParameters(tinyxml2::XMLElement* v) {
void MyNewFunction::parseParameters(const CostFunctionParameters & params) {
//Implementation
}
......@@ -77,7 +77,8 @@ class CostFunction {
//float lambdaMovementCost_;
//WorldBase* world_;
std::string name_;
float coefficient_ = 1;
float coefficient_ = 1; //tipically scaling factor of the result of the cost function
float range_ = 100; //range of interaction. Typically the range for neighbor search
public:
CostFunction();
......@@ -90,7 +91,7 @@ class CostFunction {
/**
* Parses the parameters of the cost function.
*/
virtual void parseParameters(tinyxml2::XMLElement* v);
virtual void parseParameters(const CostFunctionParameters & params);
virtual ~CostFunction();
std::string getName() const;
......
......@@ -41,7 +41,7 @@ class Policy {
public:
Vector2D getNewVelocity(Agent* agent, WorldBase * world);
void addCostFunction(const std::shared_ptr<CostFunction>& costFunction, tinyxml2::XMLElement* args);
void addCostFunction(const std::shared_ptr<CostFunction>& costFunction, const CostFunctionParameters &params);
//std::vector<double> getParameters() const;
std::vector<std::shared_ptr<CostFunction> > & getCostFunctions();
......
......@@ -27,7 +27,10 @@
#include <string>
#include <iostream>
#include <vector>
#include "core/crowdSimulator.h"
#include "tools/vector2D.h"
#include "tinyxml2.h"
class CrowdSimulator;
class XMLParser
{
......@@ -36,9 +39,24 @@ public:
void loadMasterConfig(const std::string &filename, CrowdSimulator * crowdsimulator);
void load(const std::string &filename, CrowdSimulator * crowdsimulator);
//std::vector<Agent*> agents;
//std::vector<Obstacle*> obstacles;
};
/*
*
* Helper class to read cost function parameters from XML file.
*
*/
class CostFunctionParameters {
tinyxml2::XMLElement* xml;
public:
CostFunctionParameters(tinyxml2::XMLElement* v) : xml(v) {}
bool ReadInt(const std::string &name, int &value) const;
bool ReadFloat(const std::string &name, float &value) const;
bool ReadBool(const std::string &name, bool &value) const;
bool ReadString(const std::string &name, std::string &value) const;
//Parses a vector. In the xml each component is written as individual floats with the following format: namex=x namey=y
bool ReadVector(const std::string &name, Vector2D &value) const;
};
#endif // _XML_PARSER_H
......@@ -75,7 +75,7 @@ CostFunctionValues FOEAvoidance::GetCostFunctionGradient(Agent* agent, WorldBase
float Cost = 0;
float gradtheta = 0;
float gradv = 0;
vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), 30);
vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), range_);
if (!(Velocity.magnitude() > 0)) {
return Result;
}
......@@ -116,9 +116,9 @@ CostFunctionValues FOEAvoidance::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void FOEAvoidance::parseParameters(tinyxml2::XMLElement * v)
void FOEAvoidance::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
CostFunction::parseParameters(params);
}
......
......@@ -50,9 +50,9 @@ CostFunctionValues GenericCost::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void GenericCost::parseParameters(tinyxml2::XMLElement * v)
void GenericCost::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
CostFunction::parseParameters(params);
}
REGISTER_COST_FUNCTION(GenericCost)
\ No newline at end of file
......@@ -52,7 +52,7 @@ CostFunctionValues PowerLaw::GetCostFunctionGradient(Agent* agent, WorldBase * w
Vector2D AgentVelocity = agent->getVelocity();
float AgentRadius = agent->getRadius();
vector<PhantomAgent> neighbors = world->getNeighboursOfAgent(agent->getID(), 10000);
vector<PhantomAgent> neighbors = world->getNeighboursOfAgent(agent->getID(), range_);
for (PhantomAgent& other : neighbors) {
if (other.realAgent == agent) {
continue;
......@@ -82,13 +82,10 @@ CostFunctionValues PowerLaw::GetCostFunctionGradient(Agent* agent, WorldBase * w
return Result;
}
void PowerLaw::parseParameters(tinyxml2::XMLElement * v)
void PowerLaw::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
float val;
if (v->QueryFloatAttribute("tau0", &val) == tinyxml2::XML_SUCCESS) {
tau0 = val;
}
CostFunction::parseParameters(params);
params.ReadFloat("tau0", tau0);
}
REGISTER_COST_FUNCTION(PowerLaw)
\ No newline at end of file
......@@ -48,9 +48,9 @@ CostFunctionValues RandomFunction::GetCostFunctionGradient(Agent* agent, WorldBa
return Result;
}
void RandomFunction::parseParameters(tinyxml2::XMLElement * v)
void RandomFunction::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
CostFunction::parseParameters(params);
}
REGISTER_COST_FUNCTION(RandomFunction)
......@@ -44,7 +44,7 @@ CostFunctionValues SocialForcesAvoidance::GetCostFunctionGradient(Agent* agent,
CostFunctionValues Result;
const Vector2D& AgentPos = agent->getPosition();
const vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), 2);
const vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), range_);
Vector2D Repulsion(0, 0); float TotalCost = 0;
float dt = world->getDeltaTime();
......@@ -75,13 +75,10 @@ CostFunctionValues SocialForcesAvoidance::GetCostFunctionGradient(Agent* agent,
return Result;
}
void SocialForcesAvoidance::parseParameters(tinyxml2::XMLElement * v)
void SocialForcesAvoidance::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
float val;
if (v->QueryFloatAttribute("sigma", &val) == tinyxml2::XML_SUCCESS) {
sigma = val;
}
CostFunction::parseParameters(params);
params.ReadFloat("sigma", sigma);
}
REGISTER_COST_FUNCTION(SocialForcesAvoidance)
\ No newline at end of file
......@@ -58,9 +58,9 @@ CostFunctionValues SocialForcesGoalReaching::GetCostFunctionGradient(Agent* agen
return Result;
}
void SocialForcesGoalReaching::parseParameters(tinyxml2::XMLElement * v)
void SocialForcesGoalReaching::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
CostFunction::parseParameters(params);
}
REGISTER_COST_FUNCTION(SocialForcesGoalReaching)
\ No newline at end of file
......@@ -75,7 +75,7 @@ CostFunctionValues TtcaDca::GetCostFunctionGradient(Agent* agent, WorldBase * wo
float GradS = 0;
//get the neighbourhood
std::vector<PhantomAgent> agts = world->getNeighboursOfAgent(agent->getID(), 50);
std::vector<PhantomAgent> agts = world->getNeighboursOfAgent(agent->getID(), range_);
//std::vector<Agent* > agts = world->getNeighboursOfAgent(agent->getID(), agent->getRadius()*20);
int NumAgentsVisible = 0;
......@@ -152,16 +152,11 @@ CostFunctionValues TtcaDca::GetCostFunctionGradient(Agent* agent, WorldBase * wo
return Result;
}
void TtcaDca::parseParameters(tinyxml2::XMLElement * v)
void TtcaDca::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
float val;
if (v->QueryFloatAttribute("sigmaTtca", &val) == tinyxml2::XML_SUCCESS) {
sigTtca_ = val;
}
if (v->QueryFloatAttribute("sigmaDca", &val) == tinyxml2::XML_SUCCESS) {
sigDca_ = val;
}
CostFunction::parseParameters(params);
params.ReadFloat("sigmaTtca", sigTtca_);
params.ReadFloat("sigmaDca", sigDca_);
}
REGISTER_COST_FUNCTION(TtcaDca)
\ No newline at end of file
......@@ -53,9 +53,10 @@ CostFunctionValues DirGoalReaching::GetCostFunctionGradient(Agent* agent, WorldB
return Result;
}
void DirGoalReaching::parseParameters(tinyxml2::XMLElement * v)
void DirGoalReaching::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
CostFunction::parseParameters(params);
}
......
......@@ -57,9 +57,9 @@ CostFunctionValues GoalReaching::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void GoalReaching::parseParameters(tinyxml2::XMLElement * v)
void GoalReaching::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
CostFunction::parseParameters(params);
}
REGISTER_COST_FUNCTION(GoalReaching)
\ No newline at end of file
......@@ -53,7 +53,7 @@ CostFunctionValues SocialForces::GetCostFunctionGradient(Agent* agent, WorldBase
const Vector2D& Atraction = -AtractorForce*(GoalDir*PreferedSpeed - CurrentVel);
const vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), 100);
const vector<PhantomAgent> Neighbors = world->getNeighboursOfAgent(agent->getID(), range_);
Vector2D Repulsion(0, 0);
float dt = 1;
......@@ -78,22 +78,12 @@ CostFunctionValues SocialForces::GetCostFunctionGradient(Agent* agent, WorldBase
return Result;
}
void SocialForces::parseParameters(tinyxml2::XMLElement * v)
void SocialForces::parseParameters(const CostFunctionParameters & params)
{
CostFunction::parseParameters(v);
float val;
if (v->QueryFloatAttribute("coeff", &val) == tinyxml2::XML_SUCCESS) {
coefficient_ = val;
}
if (v->QueryFloatAttribute("AtractionForce", &val) == tinyxml2::XML_SUCCESS) {
AtractorForce = val;
}
if (v->QueryFloatAttribute("RepulsionForce", &val) == tinyxml2::XML_SUCCESS) {
RepulsionForce = val;
}
if (v->QueryFloatAttribute("sigma", &val) == tinyxml2::XML_SUCCESS) {
sigma = val;
}
CostFunction::parseParameters(params);
params.ReadFloat("AtractionForce", AtractorForce);
params.ReadFloat("RepulsionForce", RepulsionForce);
params.ReadFloat("sigma", sigma);
}
REGISTER_COST_FUNCTION(SocialForces)
\ No newline at end of file
......@@ -44,24 +44,9 @@ CostFunction::~CostFunction() {
}
void CostFunction::parseParameters(tinyxml2::XMLElement* v) {
float val;
if (v->QueryFloatAttribute("coeff", &val) == tinyxml2::XML_SUCCESS) {
coefficient_ = val;
}
void CostFunction::parseParameters(const CostFunctionParameters & params) {
coefficient_ = 1; //Default values
range_ = 100;
params.ReadFloat("coeff", coefficient_);
params.ReadFloat("range", range_);
}
/*
float CostFunction::goalCost() {
return 0.0;
}
float CostFunction::avoidanceCost() {
return 0.0;
}
float CostFunction::totalCost() {
return 0.0;
}
*/
......@@ -20,7 +20,7 @@
**
** Contact : crowd_group@inria.fr
*/
#include "tinyxml2.h"
#include "tools/xmlparser.h"
#include <core/policy.h>
#include <core/agent.h>
#include <core/worldBase.h>
......@@ -51,7 +51,7 @@ Vector2D Policy::getNewVelocity(Agent* agent, WorldBase * world) {
return NewVel.getnormalized()*clamp(NewVel.magnitude(), 0.f, agent->getMaximumSpeed());
}
void Policy::addCostFunction(const std::shared_ptr<CostFunction>& costFunction, tinyxml2::XMLElement* args) {
costFunction->parseParameters(args);
void Policy::addCostFunction(const std::shared_ptr<CostFunction>& costFunction, const CostFunctionParameters &params) {
costFunction->parseParameters(params);
cost_functions_.push_back(costFunction);
}
......@@ -21,13 +21,14 @@
** Contact : crowd_group@inria.fr
*/
#include "tinyxml2.h"
#include "tools/xmlparser.h"
#include "core/agent.h"
#include "core/obstacle.h"
#include "core/worldBase.h"
#include "core/worldToric.h"
#include "core/costFunctionFactory.h"
#include "core/crowdSimulator.h"
#include <fstream>
#include <map>
......@@ -153,7 +154,7 @@ void XMLParser::load(const std::string &filename, CrowdSimulator * crowdsimulato
costFunction =
CostFunctionFactory::CreateCostFunction(costFunctionName);
if (costFunction) {
pl->addCostFunction(costFunction, funcElement);
pl->addCostFunction(costFunction, CostFunctionParameters(funcElement));
}
} while ((funcElement = funcElement->NextSiblingElement()) != NULL);
......@@ -250,3 +251,46 @@ void XMLParser::load(const std::string &filename, CrowdSimulator * crowdsimulato
} while ((agentElement = agentElement->NextSiblingElement()) != NULL);
}
bool CostFunctionParameters::ReadInt(const std::string &name, int &value) const {
if (xml->QueryAttribute(name.c_str(), &value) == tinyxml2::XML_SUCCESS) {
return true;
}
return false;
}
bool CostFunctionParameters::ReadFloat(const std::string &name, float &value) const {
if (xml->QueryAttribute(name.c_str(), &value) == tinyxml2::XML_SUCCESS) {
return true;
}
return false;
}
bool CostFunctionParameters::ReadBool(const std::string &name, bool &value) const {
if (xml->QueryAttribute(name.c_str(), &value) == tinyxml2::XML_SUCCESS) {
return true;
}
return false;
}
bool CostFunctionParameters::ReadString(const std::string &name, std::string &value) const {
const char * str = 0;
if (xml->QueryStringAttribute(name.c_str(), &str) == tinyxml2::XML_SUCCESS) {
value = std::string(str);
delete[] str;
return true;
}
return false;
}
bool CostFunctionParameters::ReadVector(const std::string &name, Vector2D &value) const {
float x, y;
bool success = true;
if (!ReadFloat(name + "x", x)) {
success = false;
}
if (!ReadFloat(name + "y", y)) {
success = false;
}
if (success) {
value.set(x, y);
}
return success;
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment