From 93a6c69f64b24dcf22fa42b59c6879490847eae3 Mon Sep 17 00:00:00 2001
From: Mathurn Videau <mathurin.videau@epfedu.fr>
Date: Tue, 21 Sep 2021 18:32:28 +0200
Subject: [PATCH] multiprocessing linux fix

---
 evolve.py                                         | 2 ++
 experiments/ADF.py                                | 2 ++
 experiments/bench.py                              | 3 ++-
 experiments/gp.py                                 | 2 ++
 experiments/imitation_learning/imitation_gp.py    | 1 +
 experiments/imitation_learning/imitation_linGP.py | 1 +
 experiments/linGP.py                              | 3 ++-
 7 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/evolve.py b/evolve.py
index 52c7efe..2ff9ff3 100644
--- a/evolve.py
+++ b/evolve.py
@@ -39,6 +39,8 @@ if "__main__" == __name__:
 
     mstats = factory.get_stats()
 
+    multiprocessing.set_start_method('spawn')
+    
     pool = multiprocessing.Pool(conf["params"]["n_thread"], initializer=factory.init_global_var)
     evoTool.toolbox.register("map", pool.map)
 
diff --git a/experiments/ADF.py b/experiments/ADF.py
index e6b5859..fc4f1b1 100644
--- a/experiments/ADF.py
+++ b/experiments/ADF.py
@@ -188,6 +188,8 @@ if __name__ == "__main__":
     conf = vars(args)
     initializer(conf)
 
+    multiprocessing.set_start_method('spawn')
+
     stats_fit = tools.Statistics(lambda ind: ind.fitness.values[0])
     stats_size = tools.Statistics(len)
     stats_bandit = tools.Statistics(lambda ind: len(ind.fitness.rewards))
diff --git a/experiments/bench.py b/experiments/bench.py
index 07591d1..e7f2593 100644
--- a/experiments/bench.py
+++ b/experiments/bench.py
@@ -148,7 +148,8 @@ if __name__ == "__main__":
     parser.add_argument("--path", help="directory for results", default="", type=str)
 
     args = parser.parse_args()
-
+    multiprocessing.set_start_method('spawn')
+    
     if args.path == "":
         args.path = os.path.join("experiments", "results", "bench", ntpath.basename(args.conf)[:-4])
 
diff --git a/experiments/gp.py b/experiments/gp.py
index 91966d6..7f57183 100644
--- a/experiments/gp.py
+++ b/experiments/gp.py
@@ -170,6 +170,8 @@ if __name__== "__main__":
 
     mstats = factory.get_stats()
 
+    multiprocessing.set_start_method('spawn')
+
     pool = multiprocessing.Pool(args.n_thread, initializer=factory.init_global_var)
     toolbox.register("map", pool.map)
     
diff --git a/experiments/imitation_learning/imitation_gp.py b/experiments/imitation_learning/imitation_gp.py
index a02b077..59fca97 100644
--- a/experiments/imitation_learning/imitation_gp.py
+++ b/experiments/imitation_learning/imitation_gp.py
@@ -192,6 +192,7 @@ if __name__== "__main__":
     stats_size = tools.Statistics(len)
     stats_bandit = tools.Statistics(lambda ind: len(ind.fitness.rewards))
 
+    multiprocessing.set_start_method('spawn')
     mstats = tools.MultiStatistics(fitness=stats_fit, size=stats_size, bandit=stats_bandit)
     mstats.register("avg", lambda x: np.mean(x, axis=0))
     mstats.register("std", lambda x: np.std(x, axis=0))
diff --git a/experiments/imitation_learning/imitation_linGP.py b/experiments/imitation_learning/imitation_linGP.py
index 59fe3d9..2e361f3 100644
--- a/experiments/imitation_learning/imitation_linGP.py
+++ b/experiments/imitation_learning/imitation_linGP.py
@@ -157,6 +157,7 @@ if __name__== "__main__":
         data.extend(MC_collect_single(demonstrator.action, ENV, 1000))
     dataset.core_transition(data)
 
+    multiprocessing.set_start_method('spawn')
     pool = multiprocessing.Pool(12, initializer=initializer, initargs=("AntPyBulletEnv-v0", 0, TimeFeatureWrapper))
     toolbox.register("map", pool.map)
 
diff --git a/experiments/linGP.py b/experiments/linGP.py
index 7d70ef2..fc5ec84 100644
--- a/experiments/linGP.py
+++ b/experiments/linGP.py
@@ -166,7 +166,8 @@ if __name__ == '__main__':
     factory.init_global_var()
 
     mstats = factory.get_stats()
-
+    
+    multiprocessing.set_start_method('spawn')
     pool = multiprocessing.Pool(args.n_thread, initializer=factory.init_global_var)
     toolbox.register("map", pool.map)
     
-- 
GitLab