Commit af4c4a9d authored by VAN TOLL Wouter's avatar VAN TOLL Wouter

- Policy now stores cost functions as CostFunction* instead of...

- Policy now stores cost functions as CostFunction* instead of shared_ptr<CostFunction>, which could cause unexpeced bottlenecks when running on multiple threads.
- CostFunction registration now uses modern C++ templates and lambda expressions, instead of C-style macros.
- Removed costFunctionRegistration.h/.cpp.
parent 1c3e4ac9
/* UMANS: Unified Microscopic Agent Navigation Simulator
** Copyright (C) 2018-2020 Inria Rennes Bretagne Atlantique - Rainbow - Julien Pettré
**
** This program is free software: you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation, either version 3 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program. If not, see <https://www.gnu.org/licenses/>.
**
** Contact: crowd_group@inria.fr
** Website: https://project.inria.fr/crowdscience/
** See the file AUTHORS.md for a list of all contributors.
*/
#include "core/costFunctionFactory.h"
#include "CostFunctions/costFunctionRegistration.h"
#include "CostFunctions/FOEAvoidance.h"
#include "CostFunctions/GenericCost.h"
#include "CostFunctions/GoalReachingForce.h"
#include "CostFunctions/Karamouzas.h"
#include "CostFunctions/Moussaid.h"
#include "CostFunctions/ORCA.h"
#include "CostFunctions/Paris.h"
#include "CostFunctions/PLEdestrians.h"
#include "CostFunctions/PowerLaw.h"
#include "CostFunctions/RandomFunction.h"
#include "CostFunctions/RVO.h"
#include "CostFunctions/SocialForcesAvoidance.h"
#include "CostFunctions/TtcaDca.h"
#include "CostFunctions/VanToll.h"
bool registered = false;
void RegisterCostFunctions()
{
if (registered)
return;
registered = true;
REGISTER_COST_FUNCTION(FOEAvoidance);
REGISTER_COST_FUNCTION(GenericCost);
REGISTER_COST_FUNCTION(GoalReachingForce);
REGISTER_COST_FUNCTION(Karamouzas);
REGISTER_COST_FUNCTION(Moussaid);
REGISTER_COST_FUNCTION(ORCA);
REGISTER_COST_FUNCTION(PowerLaw);
REGISTER_COST_FUNCTION(RandomFunction);
REGISTER_COST_FUNCTION(SocialForcesAvoidance);
REGISTER_COST_FUNCTION(TtcaDca);
REGISTER_COST_FUNCTION(RVO);
REGISTER_COST_FUNCTION(Paris);
REGISTER_COST_FUNCTION(PLEdestrians);
REGISTER_COST_FUNCTION(VanToll);
}
\ No newline at end of file
/* UMANS: Unified Microscopic Agent Navigation Simulator
** Copyright (C) 2018-2020 Inria Rennes Bretagne Atlantique - Rainbow - Julien Pettré
**
** This program is free software: you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation, either version 3 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program. If not, see <https://www.gnu.org/licenses/>.
**
** Contact: crowd_group@inria.fr
** Website: https://project.inria.fr/crowdscience/
** See the file AUTHORS.md for a list of all contributors.
*/
#ifndef LIB_COSTFUNCTIONREGISTRATION_H
#define LIB_COSTFUNCTIONREGISTRATION_H
void RegisterCostFunctions();
#endif //LIB_COSTFUNCTIONREGISTRATION_H
......@@ -64,7 +64,7 @@ Vector2D CostFunction::GetGlobalMinimum(Agent* agent, const WorldBase* world) co
}
Vector2D CostFunction::ApproximateGlobalMinimumBySampling(Agent* agent, const WorldBase* world,
const SamplingParameters& params, const CostFunctionList_Pointers& costFunctions)
const SamplingParameters& params, const CostFunctionList& costFunctions)
{
// --- Compute the range in which samples will be taken.
......
......@@ -41,8 +41,7 @@ typedef std::vector<PhantomAgent> AgentNeighborList;
typedef std::vector<LineSegment2D> ObstacleNeighborList;
typedef std::pair<AgentNeighborList, ObstacleNeighborList> NeighborList;
typedef std::vector<std::pair<std::shared_ptr<const CostFunction>, float>> CostFunctionList;
typedef std::vector<std::pair<const CostFunction*, float>> CostFunctionList_Pointers;
typedef std::vector<std::pair<const CostFunction*, float>> CostFunctionList;
const float MaxFloat = std::numeric_limits<float>::max();
......@@ -60,9 +59,8 @@ const float MaxFloat = std::numeric_limits<float>::max();
/// <item>(optionally) parseParameters(): This method should parse any additional parameters that are specific to your cost function.
/// Make sure to call the parent version of parseParameters() here, to parse any parameters that were already defined.</item>
/// </list>
/// Finally, to be able to use your cost function, add the following line to the file core/costFunctionRegistration.cpp:
///
/// REGISTER_COST_FUNCTION(MyNewFunction)
/// Finally, to be able to use your cost function, add the following line to the file core/costFunctionFactory.cpp:
/// <code>registerCostFunction<MyNewFunction>();</code>
///
/// This will make sure that the program can dynamically create instances of your cost function when it is used in an XML file.
///
......@@ -116,12 +114,12 @@ Add a file CostFunctions/MyNewFunction.cpp:
}
</code>
Add the following lines to core/costFunctionRegistration.cpp:
Add the following lines to core/costFunctionFactory.cpp:
<code>
#include <CostFunctions/MyNewFunction.h>
//// ...
REGISTER_COST_FUNCTION(MyNewFunction)
registerCostFunction<MyNewFunction>();
</code>
*/
......@@ -187,7 +185,7 @@ public:
/// <param name="costFunctions">A list of cost functions to evaluate.</param>
/// <returns>The sample velocity for which the sum of all cost-function values is lowest.</param>
static Vector2D ApproximateGlobalMinimumBySampling(Agent* agent, const WorldBase* world,
const SamplingParameters& params, const CostFunctionList_Pointers& costFunctions);
const SamplingParameters& params, const CostFunctionList& costFunctions);
/// <summary>Parses the parameters of the cost function.</summary>
/// <remarks>By default, this method already loads the "range" parameter.
......
......@@ -20,46 +20,38 @@
*/
#include "core/costFunctionFactory.h"
#include <iostream>
CostFunctionFactory::CostFunctionFactory()
#include "CostFunctions/FOEAvoidance.h"
#include "CostFunctions/GenericCost.h"
#include "CostFunctions/GoalReachingForce.h"
#include "CostFunctions/Karamouzas.h"
#include "CostFunctions/Moussaid.h"
#include "CostFunctions/ORCA.h"
#include "CostFunctions/Paris.h"
#include "CostFunctions/PLEdestrians.h"
#include "CostFunctions/PowerLaw.h"
#include "CostFunctions/RandomFunction.h"
#include "CostFunctions/RVO.h"
#include "CostFunctions/SocialForcesAvoidance.h"
#include "CostFunctions/TtcaDca.h"
#include "CostFunctions/VanToll.h"
CostFunctionFactory::Registry CostFunctionFactory::registry = CostFunctionFactory::Registry();
void CostFunctionFactory::RegisterAllCostFunctions()
{
}
CostFunctionFactory::~CostFunctionFactory()
{
}
CostFunctionFactory::Registry& CostFunctionFactory::GetRegistry() {
static Registry registry;
return registry;
}
void CostFunctionFactory::RegisterCostFunction(const std::string &name, Creator creator) {
Registry &registry = GetRegistry();
if (registry.count(name) > 0) {
std::cerr << "Error: cost function " << name << " has already been registered." << std::endl;
return;
}
registry[name] = creator;
}
std::shared_ptr<CostFunction>
CostFunctionFactory::CreateCostFunction(const std::string& name) {
Registry &registry = GetRegistry();
if (registry.count(name) == 0) {
std::cerr << "Error: cost function " << name << " does not exist. The following cost functions are known: ";
for (auto &elm : registry) {
std::cerr << ", " << elm.first;
}
std::cerr << "." << std::endl;
return nullptr;
}
return registry[name]();
}
CostFunctionRegistrator::CostFunctionRegistrator(const std::string &name, CostFunctionFactory::Creator creator) {
CostFunctionFactory::RegisterCostFunction(name, creator);
registerCostFunction<FOEAvoidance>();
registerCostFunction<GenericCost>();
registerCostFunction<GoalReachingForce>();
registerCostFunction<Karamouzas>();
registerCostFunction<Moussaid>();
registerCostFunction<ORCA>();
registerCostFunction<PowerLaw>();
registerCostFunction<RandomFunction>();
registerCostFunction<SocialForcesAvoidance>();
registerCostFunction<TtcaDca>();
registerCostFunction<RVO>();
registerCostFunction<Paris>();
registerCostFunction<PLEdestrians>();
registerCostFunction<VanToll>();
}
\ No newline at end of file
......@@ -22,41 +22,48 @@
#ifndef _COST_FUNCTION_FACTORY_H
#define _COST_FUNCTION_FACTORY_H
#include <memory>
#include <core/costFunction.h>
#include <functional>
#include <map>
#include <string>
#include <memory>
#include <iostream>
class CostFunctionFactory
{
public:
typedef std::function < std::shared_ptr<CostFunction>() > Creator;
typedef std::map<std::string, Creator> Registry;
CostFunctionFactory();
~CostFunctionFactory();
static std::shared_ptr<CostFunction> CreateCostFunction(const std::string& name);
static void RegisterCostFunction(const std::string &name, Creator creator);
typedef std::map<std::string, std::function<CostFunction*()>> Registry;
private:
static Registry& GetRegistry();
};
static Registry registry;
template<typename Type>
std::shared_ptr<CostFunction> DefaultObjectCreator() {
return std::make_shared<Type>();
}
private:
template<typename CostFunctionType> static void registerCostFunction()
{
const std::string& name = CostFunctionType::GetName();
if (registry.count(name) > 0)
{
std::cerr << "Error: cost function " << name << " has already been registered." << std::endl;
return;
}
registry[name] = [] { return new CostFunctionType(); };
}
class CostFunctionRegistrator {
public:
CostFunctionRegistrator(const std::string &name, CostFunctionFactory::Creator creator);
};
static void RegisterAllCostFunctions();
#define REGISTER_COST_FUNCTION(classname) CostFunctionRegistrator g_##classname(classname::GetName(), DefaultObjectCreator<classname>);
static CostFunction* CreateCostFunction(const std::string& name)
{
if (registry.count(name) == 0)
{
std::cerr << "Error: cost function " << name << " does not exist. The following cost functions are known: ";
for (auto &elm : registry)
std::cerr << ", " << elm.first;
std::cerr << "." << std::endl;
return nullptr;
}
return registry[name]();
}
};
#endif
......@@ -21,16 +21,15 @@
#include <core/crowdSimulator.h>
#include "tools/csvwriter.h"
#include "core/worldInfinite.h"
#include "core/worldToric.h"
#include "core/costFunctionFactory.h"
#include "CostFunctions/costFunctionRegistration.h"
#include <tools/csvwriter.h>
#include <core/worldInfinite.h>
#include <core/worldToric.h>
#include <core/costFunctionFactory.h>
CrowdSimulator::CrowdSimulator(bool isConsoleApplication)
: isConsoleApplication_(isConsoleApplication)
{
RegisterCostFunctions();
CostFunctionFactory::RegisterAllCostFunctions();
writer_ = nullptr;
end_time_ = MaxFloat;
}
......@@ -110,10 +109,10 @@ void CrowdSimulator::RunSimulationUntilEnd()
for (int i = 0; i < nrIterations_; ++i)
{
RunSimulationSteps(1);
if (isConsoleApplication_ && i%p == 0)
if (isConsoleApplication_ && i%p == 0)
std::cout << "#" << std::flush;
}
if (isConsoleApplication_)
std::cout << "]" << std::endl;
......@@ -360,8 +359,8 @@ bool CrowdSimulator::FromConfigFile_loadSinglePolicy(const tinyxml2::XMLElement*
while (funcElement != nullptr)
{
const auto& costFunctionName = funcElement->Attribute("name");
std::shared_ptr<CostFunction> costFunction = CostFunctionFactory::CreateCostFunction(costFunctionName);
if (costFunction)
CostFunction* costFunction = CostFunctionFactory::CreateCostFunction(costFunctionName);
if (costFunction != nullptr)
pl->AddCostFunction(costFunction, CostFunctionParameters(funcElement));
funcElement = funcElement->NextSiblingElement();
......
......@@ -23,6 +23,14 @@
#include <core/agent.h>
#include <core/worldBase.h>
Policy::~Policy()
{
// delete all cost functions
for (auto& costFunction : cost_functions_)
delete costFunction.first;
cost_functions_.clear();
}
float Policy::getInteractionRange() const
{
float range = 0;
......@@ -93,13 +101,7 @@ Vector2D Policy::getBestVelocityGlobal(Agent* agent, WorldBase * world)
Vector2D Policy::getBestVelocitySampling(Agent* agent, WorldBase * world, const SamplingParameters& params)
{
// convert the cost-function list to other types of pointers
size_t f = cost_functions_.size();
CostFunctionList_Pointers costFunctions_pointers(f);
for (size_t i = 0; i < f; ++i)
costFunctions_pointers[i] = { cost_functions_[i].first.get(), cost_functions_[i].second };
return CostFunction::ApproximateGlobalMinimumBySampling(agent, world, params, costFunctions_pointers);
return CostFunction::ApproximateGlobalMinimumBySampling(agent, world, params, cost_functions_);
}
Vector2D Policy::ComputeContactForces(Agent* agent, WorldBase * world)
......@@ -138,7 +140,7 @@ Vector2D Policy::ComputeContactForces(Agent* agent, WorldBase * world)
return totalForce;
}
void Policy::AddCostFunction(const std::shared_ptr<CostFunction>& costFunction, const CostFunctionParameters &params)
void Policy::AddCostFunction(CostFunction* costFunction, const CostFunctionParameters &params)
{
float coefficient = 1;
params.ReadFloat("coeff", coefficient);
......
......@@ -132,6 +132,9 @@ public:
Policy(OptimizationMethod method, SamplingParameters params)
: optimizationMethod_(method), samplingParameters_(params) {}
/// <summary>Destroys this Policy and all cost functions inside it.</summary>
~Policy();
/// <summary>Computes and returns a new velocity for a given agent, using the cost functions and optimization method of this Policy.</summary>
/// <param name="agent">The agent for which a new velocity should be computed.</param>
/// <param name="world">The world in which the simulation takes place.</param>
......@@ -146,7 +149,7 @@ public:
/// <summary>Adds a cost function to this Policy's list of cost functions.</summary>
/// <param name="costFunction">A pointer to an alraedy created cost function.</param>
/// <param name="params">An XML object containing all parameters that the cost function may want to read.</param>
void AddCostFunction(const std::shared_ptr<CostFunction>& costFunction, const CostFunctionParameters& params);
void AddCostFunction(CostFunction* costFunction, const CostFunctionParameters& params);
/// <summary>Returns the number of cost functions used by this Policy.</summary>
size_t GetNumberOfCostFunctions() const { return cost_functions_.size(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment