Commit cd415494 authored by CARDOSI Paul's avatar CARDOSI Paul
Browse files

Merge branch 'taskwait-equivalent' into 'master'

Implementation of finish method of task graph.

See merge request bramas/spetabaru!29
parents 4935fafd 31f3f7ec
#include <optional>
#include <mutex>
#include <algorithm>
#include "SpComputeEngine.hpp"
#include "TaskGraph/SpAbstractTaskGraph.hpp"
#include "Tasks/SpAbstractTask.hpp"
#include "SpWorker.hpp"
// Attach a task graph to this compute engine so that its ready tasks can be
// dispatched to the engine's workers. A null graph pointer is ignored.
void SpComputeEngine::addGraph(SpAbstractTaskGraph* tg) {
    if(tg == nullptr) {
        return;
    }
    // Let the graph know where to push its tasks, then remember it.
    tg->setComputeEngine(this);
    taskGraphs.push_back(tg);
}
void SpComputeEngine::stopIfNotAlreadyStopped() {
if(!hasBeenStopped) {
{
......@@ -31,3 +25,10 @@ void SpComputeEngine::stopIfNotAlreadyStopped() {
hasBeenStopped = true;
}
}
// Block the calling worker on the engine's condition variable until there is
// something for it to do: the worker was stopped, a migration request targets
// workers of its type, a task became ready, or (when atg is non-null) the
// given task graph finished. The available-worker counters are bumped while
// the worker sleeps so the scheduler knows it can be woken.
void SpComputeEngine::wait(SpWorker& worker, SpAbstractTaskGraph* atg) {
std::unique_lock<std::mutex> ceLock(ceMutex);
// Mark this worker as available (asleep) for its type while it waits.
updateWorkerCounters<false, true>(worker.getType(), +1);
ceCondVar.wait(ceLock, [&]() { return worker.hasBeenStopped() || areThereAnyWorkersToMigrate() || areThereAnyReadyTasks() || (atg && atg->isFinished());});
// Worker is active again: it is no longer available for wake-ups.
updateWorkerCounters<false, true>(worker.getType(), -1);
}
......@@ -18,7 +18,6 @@ class SpComputeEngine {
private:
small_vector<std::unique_ptr<SpWorker>> workers;
small_vector<SpAbstractTaskGraph*> taskGraphs;
std::mutex ceMutex;
std::condition_variable ceCondVar;
std::mutex migrationMutex;
......@@ -36,13 +35,6 @@ private:
private:
// Wake every worker blocked in wait(). The mutex is acquired and immediately
// released before notifying: this handshake ensures a worker that has already
// evaluated its wait predicate but not yet blocked on the condition variable
// cannot miss the notification.
void wakeUpWaitingWorkers() {
{
std::unique_lock<std::mutex> workerLock(ceMutex);
}
ceCondVar.notify_all();
}
auto sendWorkersToInternal(SpComputeEngine *otherCe, const SpWorker::SpWorkerType wt, const long int maxCount, const bool allowBusyWorkersToBeDetached) {
small_vector<std::unique_ptr<SpWorker>> res;
using iter_t = small_vector<std::unique_ptr<SpWorker>>::iterator;
......@@ -174,12 +166,7 @@ private:
}
}
// Block the calling worker until it is stopped, a migration request targets
// workers to move, or a task becomes ready. While asleep the worker is
// counted as available for its type so wake-up decisions can account for it.
void wait(SpWorker& worker) {
std::unique_lock<std::mutex> ceLock(ceMutex);
// Mark this worker as available (asleep) for its type.
updateWorkerCounters<false, true>(worker.getType(), +1);
ceCondVar.wait(ceLock, [&]() { return worker.hasBeenStopped() || areThereAnyWorkersToMigrate() || areThereAnyReadyTasks();});
// Worker resumed: no longer available for wake-ups.
updateWorkerCounters<false, true>(worker.getType(), -1);
}
void wait(SpWorker& worker, SpAbstractTaskGraph* atg);
auto getCeToMigrateTo() {
return ceToMigrateTo;
......@@ -200,23 +187,23 @@ private:
return migrationSignalingCounter.fetch_sub(1, std::memory_order_release);
}
friend void SpWorker::start();
friend void SpWorker::waitOnCe(SpComputeEngine*);
friend void SpWorker::waitOnCe(SpComputeEngine* inCe, SpAbstractTaskGraph* atg);
friend void SpWorker::doLoop(SpAbstractTaskGraph* atg);
public:
explicit SpComputeEngine(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers = SpWorker::createDefaultWorkerTeam())
: workers(), taskGraphs(), ceMutex(), ceCondVar(), migrationMutex(), migrationCondVar(), prioSched(), nbWorkersToMigrate(0),
explicit SpComputeEngine(small_vector_base<std::unique_ptr<SpWorker>>&& inWorkers)
: workers(), ceMutex(), ceCondVar(), migrationMutex(), migrationCondVar(), prioSched(), nbWorkersToMigrate(0),
migrationSignalingCounter(0), workerTypeToMigrate(SpWorker::SpWorkerType::CPU_WORKER), ceToMigrateTo(nullptr), nbAvailableCpuWorkers(0),
nbAvailableGpuWorkers(0), totalNbCpuWorkers(0), totalNbGpuWorkers(0), hasBeenStopped(false) {
addWorkers(std::move(inWorkers));
}
explicit SpComputeEngine() : SpComputeEngine(small_vector<std::unique_ptr<SpWorker>, 0>{}) {}
// Ensure the engine is cleanly shut down (workers stopped and joined via
// stopIfNotAlreadyStopped) before the members are destroyed.
~SpComputeEngine() {
stopIfNotAlreadyStopped();
}
void addGraph(SpAbstractTaskGraph* tg);
void pushTask(SpAbstractTask* t) {
prioSched.push(t);
wakeUpWaitingWorkers();
......@@ -246,6 +233,13 @@ public:
}
void stopIfNotAlreadyStopped();
// Wake every worker blocked in wait(). The mutex is acquired and immediately
// released before notifying: this handshake ensures a worker that has already
// evaluated its wait predicate but not yet blocked on the condition variable
// cannot miss the notification.
void wakeUpWaitingWorkers() {
{
std::unique_lock<std::mutex> ceLock(ceMutex);
}
ceCondVar.notify_all();
}
};
#endif
......@@ -4,63 +4,81 @@
std::atomic<long int> SpWorker::totalNbThreadsCreated = 1;
thread_local SpWorker* workerForThread = nullptr;
// Launch this worker's thread if it has not been started yet.
//
// NOTE(review): the source span contained interleaved removed/added diff
// lines (the old in-thread migration loop fused with the new doLoop call,
// leaving unbalanced braces). This is the reconstructed merged version: the
// spawned thread registers its thread id and the thread-local worker pointer,
// then runs the generic work loop with no particular task graph (nullptr).
// The migration/execute logic now lives in SpWorker::doLoop.
void SpWorker::start() {
    if(!t.joinable()) {
        t = std::thread([&]() {
            SpUtils::SetThreadId(threadId);
            SpWorker::setWorkerForThread(this);
            doLoop(nullptr);
        });
    }
}
// Put this worker to sleep on the given compute engine until work arrives,
// the worker is stopped, or atg (if non-null) finishes. Simply delegates to
// the engine's wait().
void SpWorker::waitOnCe(SpComputeEngine* inCe, SpAbstractTaskGraph* atg) {
inCe->wait(*this, atg);
}
// Record w as the SpWorker bound to the calling thread (thread-local),
// so it can later be retrieved with getWorkerForThread().
void SpWorker::setWorkerForThread(SpWorker *w) {
workerForThread = w;
}
// Return the SpWorker bound to the calling thread, or nullptr if the calling
// thread is not a worker thread (setWorkerForThread was never called on it).
SpWorker* SpWorker::getWorkerForThread() {
return workerForThread;
}
// Main work loop of a worker thread.
//
// NOTE(review): the source span contained interleaved removed/added diff
// lines (the old ready-task block and old waitOnCe call fused into the new
// body, plus a stray lambda terminator). This is the reconstructed merged
// version.
//
// If inAtg is null the loop runs until the worker is stopped; otherwise it
// also exits as soon as inAtg reports itself finished — this is how
// SpAbstractTaskGraph::finish() lets the calling worker keep executing tasks
// until the graph completes. On each iteration the worker, in order:
//   1. honors a pending migration request targeting its worker type
//      (rebinding itself to the destination compute engine),
//   2. otherwise pops and executes one ready task,
//   3. otherwise blocks on the compute engine's condition variable.
void SpWorker::doLoop(SpAbstractTaskGraph* inAtg) {
    while(!stopFlag.load(std::memory_order_relaxed) && (!inAtg || !inAtg->isFinished())) {
        SpComputeEngine* saveCe = nullptr;

        // Using memory order acquire on ce.load to form a release/acquire
        // pair with the store that published the compute engine pointer.
        if((saveCe = ce.load(std::memory_order_acquire))) {
            if(saveCe->areThereAnyWorkersToMigrate()) {
                if(saveCe->areWorkersToMigrateOfType(wt)) {
                    auto previousNbOfWorkersToMigrate = saveCe->fetchDecNbOfWorkersToMigrate();
                    if(previousNbOfWorkersToMigrate > 0) {
                        // This worker was claimed by the migration request:
                        // rebind it to the destination engine.
                        SpComputeEngine* newCe = saveCe->getCeToMigrateTo();
                        ce.store(newCe, std::memory_order_relaxed);

                        // The last migrating worker notifies the requester.
                        auto previousMigrationSignalingCounterVal = saveCe->fetchDecMigrationSignalingCounter();
                        if(previousMigrationSignalingCounterVal == 1) {
                            saveCe->notifyMigrationFinished();
                        }
                        continue;
                    }
                }
            }

            if(saveCe->areThereAnyReadyTasks()){
                SpAbstractTask* task = saveCe->getTask();
                if(task) {
                    // Run the task bracketed by its graph's pre/post hooks.
                    SpAbstractTaskGraph* atg = task->getAbstractTaskGraph();
                    atg->preTaskExecution(task);
                    execute(task);
                    atg->postTaskExecution(task);
                    continue;
                }
            }

            // Nothing to do: sleep until woken (new task, stop, migration
            // request, or inAtg finished).
            waitOnCe(saveCe, inAtg);
        } else {
            // Not attached to a compute engine yet.
            idleWait();
        }
    }
}
// Put this worker to sleep on the given compute engine until work arrives or
// the worker is stopped. Simply delegates to the engine's wait().
void SpWorker::waitOnCe(SpComputeEngine* inCe) {
inCe->wait(*this);
}
......@@ -12,6 +12,7 @@
#include "Utils/small_vector.hpp"
class SpComputeEngine;
class SpAbstractTaskGraph;
class SpWorker {
public:
......@@ -36,6 +37,9 @@ public:
static auto createDefaultWorkerTeam() {
return createATeamOfNCpuWorkers(SpUtils::DefaultNumThreads());
}
static void setWorkerForThread(SpWorker *w);
static SpWorker* getWorkerForThread();
private:
const SpWorkerType wt;
......@@ -104,7 +108,7 @@ private:
workerConditionVariable.wait(workerLock, [&]() { return stopFlag.load(std::memory_order_relaxed) || ce.load(std::memory_order_relaxed); });
}
void waitOnCe(SpComputeEngine* inCe);
void waitOnCe(SpComputeEngine* inCe, SpAbstractTaskGraph* atg);
friend class SpComputeEngine;
......@@ -130,6 +134,8 @@ public:
}
void start();
void doLoop(SpAbstractTaskGraph* inAtg);
};
#endif
......@@ -31,7 +31,7 @@ public:
///////////////////////////////////////////////////////////////////////////
// Build a runtime with its own task graph and a compute engine backed by
// inNumThreads CPU workers, then bind the graph to the engine.
// NOTE(review): the source span fused the removed `ce.addGraph(...)` line
// with its replacement `tg.computeOn(ce)`, which would register the graph
// twice; only the merged (new) call is kept.
explicit SpRuntime(const int inNumThreads = SpUtils::DefaultNumThreads()) : tg(), ce(SpWorker::createATeamOfNCpuWorkers(inNumThreads)) {
    tg.computeOn(ce);
}
///////////////////////////////////////////////////////////////////////////
......
......@@ -26,10 +26,11 @@ class SpPrioScheduler{
mutable std::mutex mutexReadyTasks;
//! Contains the tasks that can be executed
std::priority_queue<SpAbstractTask*, small_vector<SpAbstractTask*>, ComparePrio > tasksReady;
std::atomic<int> nbReadyTasks;
public:
explicit SpPrioScheduler() {
explicit SpPrioScheduler() : mutexReadyTasks(), tasksReady(), nbReadyTasks(0) {
}
// No copy or move
......@@ -39,18 +40,19 @@ public:
SpPrioScheduler& operator=(SpPrioScheduler&&) = delete;
// Return the current number of ready tasks. Reads the atomic counter, so no
// lock on mutexReadyTasks is required (the value may be slightly stale by the
// time the caller uses it).
// NOTE(review): the source span fused the removed locked implementation
// (mutex + tasksReady.size()) with the new lock-free one, leaving two return
// statements; only the merged (new) body is kept.
int getNbTasks() const{
    return nbReadyTasks;
}
// Enqueue one ready task. The counter is bumped under the same lock that
// protects the priority queue so both stay consistent. Returns 1 (the number
// of tasks pushed).
int push(SpAbstractTask* newTask){
std::unique_lock<std::mutex> locker(mutexReadyTasks);
nbReadyTasks++;
tasksReady.push(newTask);
return 1;
}
int pushTasks(small_vector_base<SpAbstractTask*>& tasks) {
std::unique_lock<std::mutex> locker(mutexReadyTasks);
nbReadyTasks += int(tasks.size());
for(auto t : tasks) {
tasksReady.push(t);
}
......@@ -60,6 +62,7 @@ public:
SpAbstractTask* pop(){
std::unique_lock<std::mutex> locker(mutexReadyTasks);
if(tasksReady.size()){
nbReadyTasks--;
auto res = tasksReady.top();
tasksReady.pop();
return res;
......
......@@ -6,7 +6,7 @@ class SpAbstractTask;
// Listener interface for objects that want to be notified when a task
// becomes ready to execute.
// NOTE(review): the source span fused the removed one-argument pure virtual
// with its two-argument replacement; keeping both would force implementers
// to override a dead overload, so only the merged (new) one is kept.
class SpAbstractToKnowReady{
public:
    virtual ~SpAbstractToKnowReady(){}
    // Called when a task becomes ready. isNotCalledInAContextOfTaskCreation
    // tells the listener whether the notification comes from outside task
    // creation (true), in which case it must do its own locking — see
    // SpTasksManager::informAllReady.
    virtual void thisTaskIsReady(SpAbstractTask*, const bool isNotCalledInAContextOfTaskCreation) = 0;
};
......
......@@ -46,12 +46,16 @@ class SpTasksManager{
//! Number of tasks that are ready
std::atomic<int> nbReadyTasks;
//! Number of finished tasks
std::atomic<int> nbFinishedTasks;
//! To protect commute locking
std::mutex mutexCommute;
small_vector<SpAbstractTask*> readyTasks;
template <const bool isNotCalledInAContextOfTaskCreation>
void insertIfReady(SpAbstractTask* aTask){
if(aTask->isState(SpTaskState::WAITING_TO_BE_READY)){
aTask->takeControl();
......@@ -72,7 +76,8 @@ class SpTasksManager{
aTask->setState(SpTaskState::READY);
aTask->releaseControl();
informAllReady(aTask);
informAllReady<isNotCalledInAContextOfTaskCreation>(aTask);
if(!ce) {
readyTasks.push_back(aTask);
......@@ -99,32 +104,27 @@ class SpTasksManager{
std::mutex listenersReadyMutex;
small_vector<SpAbstractToKnowReady*> listenersReady;
// Notify every registered listener that aTask has become ready.
// When the notification does not come from the task-creation context the
// listeners mutex is taken around the iteration; during task creation the
// creating thread already holds the mutex (via lockListenersReadyMutex), so
// locking here would deadlock. The compile-time flag lets the lock be
// elided entirely in that case.
// NOTE(review): the source span fused the removed runtime lockerByThread0
// checks with the new `if constexpr` guards and both call forms; only the
// merged (new) body is kept.
template <const bool isNotCalledInAContextOfTaskCreation>
void informAllReady(SpAbstractTask* aTask){
    if constexpr (isNotCalledInAContextOfTaskCreation){
        listenersReadyMutex.lock();
    }
    for(SpAbstractToKnowReady* listener : listenersReady){
        listener->thisTaskIsReady(aTask, isNotCalledInAContextOfTaskCreation);
    }
    if constexpr (isNotCalledInAContextOfTaskCreation){
        listenersReadyMutex.unlock();
    }
}
std::atomic<bool> lockerByThread0;
public:
// Take the listeners mutex from the task-creation thread (thread 0), so that
// ready notifications raised during task creation need no per-notification
// locking (see informAllReady<false>).
// NOTE(review): the removed lockerByThread0 bookkeeping lines from the diff
// are dropped (the member no longer exists); the thread-0 assertion is kept.
void lockListenersReadyMutex(){
    assert(SpUtils::GetThreadId() == 0);
    listenersReadyMutex.lock();
}
// Release the listeners mutex taken by lockListenersReadyMutex(); must be
// called from the same task-creation thread (thread 0).
// NOTE(review): the removed lockerByThread0 bookkeeping lines from the diff
// are dropped (the member no longer exists); the thread-0 assertion is kept.
void unlockListenersReadyMutex(){
    assert(SpUtils::GetThreadId() == 0);
    listenersReadyMutex.unlock();
}
......@@ -136,9 +136,7 @@ public:
///////////////////////////////////////////////////////////////////////////////////////
// Initialize all counters to zero; the compute engine is attached later via
// setComputeEngine.
// NOTE(review): the source span fused the removed constructor (which
// initialized the now-removed lockerByThread0 member) with its replacement;
// only the merged (new) definition is kept.
explicit SpTasksManager() : ce(nullptr), nbRunningTasks(0), nbPushedTasks(0), nbReadyTasks(0), nbFinishedTasks(0) {}
// No copy or move
SpTasksManager(const SpTasksManager&) = delete;
......@@ -183,8 +181,8 @@ public:
// Register a freshly created task and move it to the ready queue immediately
// if its dependencies are already satisfied. This runs in the task-creation
// context, hence insertIfReady<false> (the listeners mutex is already held by
// the creating thread — see lockListenersReadyMutex).
// NOTE(review): the source span fused the removed statements
// (`nbPushedTasks += 1; insertIfReady(newTask);`) with their replacements,
// which would double-count and double-insert; only the merged body is kept.
void addNewTask(SpAbstractTask* newTask){
    nbPushedTasks++;
    insertIfReady<false>(newTask);
}
int getNbReadyTasks() const{
......@@ -219,25 +217,47 @@ public:
SpDebugPrint() << "Proceed candidates from after " << t->getId() << ", they are " << candidates.size();
for(auto otherId : candidates){
SpDebugPrint() << "Test " << otherId->getId();
insertIfReady(otherId);
insertIfReady<true>(otherId);
}
t->setState(SpTaskState::FINISHED);
t->releaseControl();
nbRunningTasks--;
// We save all of the following values because the SpTasksManager
// instance might get destroyed as soon as the mutex (mutexFinishedTasks)
// protected region below has been executed.
auto previousCntVal = nbFinishedTasks.fetch_add(1);
auto nbPushedTasksVal = nbPushedTasks.load();
SpComputeEngine *saveCe = ce.load();
{
// In this case the lock on mutexFinishedTasks should be held
// while doing the notify on conditionAllTasksOver
// (conditionAllTasksOver.notify_one()) because we don't want
// the condition variable to get destroyed before we were able
// to notify.
std::unique_lock<std::mutex> locker(mutexFinishedTasks);
tasksFinished.emplace_back(t);
nbRunningTasks -= 1;
// We notify conditionAllTasksOver every time because of
// waitRemain
conditionAllTasksOver.notify_one();
}
t->setState(SpTaskState::FINISHED);
t->releaseControl();
std::unique_lock<std::mutex> locker(mutexFinishedTasks);
conditionAllTasksOver.notify_one();
if(nbPushedTasksVal == (previousCntVal + 1)) {
saveCe->wakeUpWaitingWorkers();
}
}
// Return the compute engine this manager dispatches to, or nullptr if none
// has been attached yet.
const SpComputeEngine* getComputeEngine() const {
return ce;
}
// True when every task ever pushed has finished executing. Both counters are
// atomics, so this is safe to poll from worker threads (e.g. from
// SpWorker::doLoop while waiting for a graph to complete).
bool isFinished() const {
return nbFinishedTasks == nbPushedTasks;
}
};
......
#include "SpAbstractTaskGraph.hpp"
#include "Compute/SpWorker.hpp"
// Block until every task of this graph has completed. Instead of sleeping,
// the calling thread — which must itself be a worker thread (asserted
// below) — re-enters the worker loop and helps execute tasks until the graph
// reports itself finished. This is the taskwait-equivalent operation.
void SpAbstractTaskGraph::finish() {
auto workerForThread = SpWorker::getWorkerForThread();
// finish() may only be called from a thread managed by an SpWorker.
assert(workerForThread && "workerForThread is nullptr");
workerForThread->doLoop(this);
}
......@@ -12,8 +12,8 @@ protected:
SpTasksManager scheduler;
public:
void setComputeEngine(SpComputeEngine* inCe) {
scheduler.setComputeEngine(inCe);
void computeOn(SpComputeEngine& inCe) {
scheduler.setComputeEngine(std::addressof(inCe));
}
void preTaskExecution(SpAbstractTask* t) {
......@@ -32,6 +32,12 @@ public:
scheduler.waitRemain(windowSize);
}
void finish();
// True when every task of this graph has finished executing; delegates to
// the underlying tasks manager.
bool isFinished() const {
return scheduler.isFinished();
}
};
#endif
......@@ -290,7 +290,8 @@ private:
auto callableTupleCopy = std::apply([](auto&&... elt) {
return std::make_tuple(std::forward<decltype(elt)>(elt)...);
}, callableTuple);
static_assert(0 < std::tuple_size<decltype(callableTupleCopy)>::value );
using DataDependencyTupleCopyTy = std::remove_reference_t<decltype(dataDependencyTupleCopy)>;
using CallableTupleCopyTy = std::remove_reference_t<decltype(callableTupleCopy)>;
......@@ -980,12 +981,12 @@ private:
const bool isEnabled){
if(isEnabled){
if(!alreadyDone){
assert(SpUtils::GetThreadId() != 0);
specGroupMutex.lock();
}
specGroupPtr->setSpeculationCurrentResult(!taskRes);
if(!alreadyDone){
assert(SpUtils::GetThreadId() != 0);
specGroupMutex.unlock();
}
}
......@@ -1484,13 +1485,19 @@ private:
/// Notify function (called by scheduler when a task is ready to run)
///////////////////////////////////////////////////////////////////////////
void thisTaskIsReady(SpAbstractTask* aTask) final {
void thisTaskIsReady(SpAbstractTask* aTask, const bool isNotCalledInAContextOfTaskCreation) final {
SpGeneralSpecGroup<SpecModel>* specGroup = aTask->getSpecGroup<SpGeneralSpecGroup<SpecModel>>();
SpDebugPrint() << "SpTaskGraph -- thisTaskIsReady -- will test ";
/*
* ATTENTION!
* TO DO :
* We should verify that this double checking doesn't cause any trouble.
*/
if(specGroup && specGroup->isSpeculationNotSet()){
if(specGroup != currentSpecGroup || SpUtils::GetThreadId() != 0){
if(isNotCalledInAContextOfTaskCreation){
specGroupMutex.lock();
}
if(specGroup->isSpeculationNotSet()){
if(specFormula){
if(specFormula(scheduler.getNbReadyTasks(), specGroup->getAllProbability())){
......@@ -1509,9 +1516,11 @@ private:
specGroup->setSpeculationActivation(false);
}
}
if(specGroup != currentSpecGroup || SpUtils::GetThreadId() != 0){
if(isNotCalledInAContextOfTaskCreation){
specGroupMutex.unlock();
}
assert(!specGroup->isSpeculationNotSet());
}
}
......
......@@ -7,6 +7,7 @@
#include <tuple>
#include <unordered_map>
#include <typeinfo>
#include "SpAbstractTask.hpp"
#include "Data/SpDataHandle.hpp"
......
......@@ -71,7 +71,7 @@ class ComputeEngineTest : public UTester< ComputeEngineTest > {
);
}
ce1.addGraph(std::addressof(tg1));
tg1.computeOn(ce1);
mainThreadPromise.get_future().get();
......@@ -81,7 +81,7 @@ class ComputeEngineTest : public UTester< ComputeEngineTest > {
UASSERTEEQUAL(static_cast<int>(workers.size()), 1);
ce2.addGraph(std::addressof(tg2));
tg2.computeOn(ce2);
promises[promises.size()-1].set_value(true);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment