Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 80230437 authored by REINKE Chris's avatar REINKE Chris
Browse files

measurement pipeline: allow to define campaign name, add missing rep numbers;...

measurement pipeline: allow to define campaign name, add missing rep numbers; behavior_generator: adapt to changes in human_aware_navigation_rl
parent 3ad36998
No related branches found
No related tags found
No related merge requests found
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
<arg name="local_occupancy_map_topic" default="/slam/local_map"/> <arg name="local_occupancy_map_topic" default="/slam/local_map"/>
<arg name="diag_timer_rate" default="10"/> <arg name="diag_timer_rate" default="10"/>
<arg name="odometry_topic" default="/mobile_base_controller/odom"/> <arg name="odometry_topic" default="/mobile_base_controller/odom"/>
<arg name="campaign_name" default="campaign"/>
<node pkg="measurement_pipeline" type="measurement_generator_main.py" name="measurement" output="screen" > <node pkg="measurement_pipeline" type="measurement_generator_main.py" name="measurement" output="screen" >
<param name="campaign_name" type="str" value="$(arg campaign_name)"/>
<param name="odometry_topic" type="str" value="$(arg odometry_topic)"/> <param name="odometry_topic" type="str" value="$(arg odometry_topic)"/>
<param name="plot_render" type="bool" value="$(arg plot_render)"/> <param name="plot_render" type="bool" value="$(arg plot_render)"/>
<param name="max_humans_world" type="int" value="$(arg max_humans_world)"/> <param name="max_humans_world" type="int" value="$(arg max_humans_world)"/>
......
...@@ -44,6 +44,7 @@ class MeasurementGenerator: ...@@ -44,6 +44,7 @@ class MeasurementGenerator:
self.namespace_slam = rospy.get_param('~namespace_slam', '/rtabmap') self.namespace_slam = rospy.get_param('~namespace_slam', '/rtabmap')
self.diag_timer_rate = rospy.get_param('~diag_timer_rate', 10) self.diag_timer_rate = rospy.get_param('~diag_timer_rate', 10)
self.start_experiment_id = rospy.get_param('~start_from_experiment_id', 1) self.start_experiment_id = rospy.get_param('~start_from_experiment_id', 1)
self.campaign_name = rospy.get_param('~campaign_name', 'campaign')
self.joint_states_data = None self.joint_states_data = None
...@@ -175,7 +176,7 @@ class MeasurementGenerator: ...@@ -175,7 +176,7 @@ class MeasurementGenerator:
self.time_start = None self.time_start = None
# Check what repetition already exist and set self.repetition_counter (ex : if there are already 5 repetition, set self.repetition_counter = 6) # Check what repetition already exist and set self.repetition_counter (ex : if there are already 5 repetition, set self.repetition_counter = 6)
self.path = self.create_experiment_folder_path(init_path='/home/ros/robot_behavior_ws/campaign/experiments/', counter_exp=self.start_experiment_id) self.path = self.create_experiment_folder_path(init_path='/home/ros/robot_behavior_ws/experiments/'+self.campaign_name+'/experiments/', counter_exp=self.start_experiment_id)
self.repetition_counter = 1 self.repetition_counter = 1
self.flag_repetition_folder_okay = False self.flag_repetition_folder_okay = False
rospy.loginfo("BehaviorGenerator Initialization Ended") rospy.loginfo("BehaviorGenerator Initialization Ended")
...@@ -247,7 +248,6 @@ class MeasurementGenerator: ...@@ -247,7 +248,6 @@ class MeasurementGenerator:
if self.stop_measurement_flag : if self.stop_measurement_flag :
print('Measurements Stoped') print('Measurements Stoped')
self.log_measurement() self.log_measurement()
self.repetition_counter += 1
self.stop_measurement_flag = False self.stop_measurement_flag = False
self.flag_repetition_folder_okay=False self.flag_repetition_folder_okay=False
self.reset_logs() self.reset_logs()
...@@ -261,14 +261,15 @@ class MeasurementGenerator: ...@@ -261,14 +261,15 @@ class MeasurementGenerator:
def set_repetition_folder_path(self): def set_repetition_folder_path(self):
# go through repetition folders from rep 1 upwards until we find an empty one
self.repetition_counter = 1
path = self.path + 'repetition_{:06d}'.format(self.repetition_counter) + '/data/' path = self.path + 'repetition_{:06d}'.format(self.repetition_counter) + '/data/'
if not os.path.exists(path): while os.path.exists(path):
os.makedirs(path) self.repetition_counter += 1
else : path = self.path + 'repetition_{:06d}'.format(self.repetition_counter) + '/data/'
while os.path.exists(path): os.makedirs(path)
self.repetition_counter +=1
path = self.path + 'repetition_{:06d}'.format(self.repetition_counter) + '/data/'
os.makedirs(path)
print('The following repetition will have the given id : {}'.format(path[-2])) print('The following repetition will have the given id : {}'.format(path[-2]))
......
...@@ -10,7 +10,7 @@ from geometry_msgs.msg import Twist, Vector3 ...@@ -10,7 +10,7 @@ from geometry_msgs.msg import Twist, Vector3
from std_msgs.msg import Float32 from std_msgs.msg import Float32
from robot_behavior.utils.drl_local_path_planner import LocalPathPlannerDRL from robot_behavior.utils.drl_local_path_planner import LocalPathPlannerDRL
from robot_behavior.utils.utils_drl_local_path_planner import * from robot_behavior.utils.utils_drl_local_path_planner import *
from human_aware_navigation_rl.agent.sac.sac import SAC from human_aware_navigation_rl.agent.sac.sac_base import SACBase
import exputils as eu import exputils as eu
class RLLocalPathPlannerActionServer: class RLLocalPathPlannerActionServer:
...@@ -102,7 +102,7 @@ class RLLocalPathPlannerActionServer: ...@@ -102,7 +102,7 @@ class RLLocalPathPlannerActionServer:
local_map_depth = 3 local_map_depth = 3
) )
class_agent_name = rospy.get_param('~agent_class', 'SAC') class_agent_name = rospy.get_param('~agent_class', 'SACBase')
# Creation of class attribute # Creation of class attribute
class_agent = eval(class_agent_name) class_agent = eval(class_agent_name)
self.agent = class_agent(env=config_env, import_weight = import_weight, config=config_agent, device = "cpu") self.agent = class_agent(env=config_env, import_weight = import_weight, config=config_agent, device = "cpu")
......
...@@ -2,10 +2,10 @@ import numpy as np ...@@ -2,10 +2,10 @@ import numpy as np
import exputils as eu import exputils as eu
import cv2 import cv2
from robot_behavior.utils.utils_drl_local_path_planner import * from robot_behavior.utils.utils_drl_local_path_planner import *
from human_aware_navigation_rl.agent.sac.sac import SAC from human_aware_navigation_rl.agent.sac.sac_base import SACBase
from human_aware_navigation_rl.agent.ddqn.ddqn import DDQN from human_aware_navigation_rl.agent.ddqn.ddqn import DDQN
from human_aware_navigation_rl.utils.distance_grid_from_occupancy_grid import raycast_from_occupancy_grid from human_aware_navigation_rl.env.utils.distance_grid_from_occupancy_grid import raycast_from_occupancy_grid
from human_aware_navigation_rl.utils.distance_grid_from_occupancy_grid import distance_and_angle_to_nearest_obstacle_from_occupancy_grid from human_aware_navigation_rl.env.utils.distance_grid_from_occupancy_grid import distance_and_angle_to_nearest_obstacle_from_occupancy_grid
class LocalPathPlannerDRL: class LocalPathPlannerDRL:
r''' LocalPathPlanner based of Deep Reinforcement Learning trained agent ''' r''' LocalPathPlanner based of Deep Reinforcement Learning trained agent '''
...@@ -131,7 +131,7 @@ class LocalPathPlannerDRL: ...@@ -131,7 +131,7 @@ class LocalPathPlannerDRL:
if self.flag_collision: if self.flag_collision:
print('Warning !!!!! The robot is close to an obstacle') print('Warning !!!!! The robot is close to an obstacle')
if isinstance(agent, SAC) : if isinstance(agent, SACBase) :
# Continuous actions # Continuous actions
prediction = agent.predict_action(observation) prediction = agent.predict_action(observation)
...@@ -140,7 +140,7 @@ class LocalPathPlannerDRL: ...@@ -140,7 +140,7 @@ class LocalPathPlannerDRL:
action_id = agent.predict_action(observation) action_id = agent.predict_action(observation)
prediction = self.discrete_action_list[action_id] prediction = self.discrete_action_list[action_id]
else : else :
raise ValueError('Problem with agent, it has to be a DDQN or a SAC') raise ValueError('Problem with agent, it has to be a DDQN or a SACBase')
if self.distance_to_goal > 0.3: if self.distance_to_goal > 0.3:
action = prediction action = prediction
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment