From cf2e714a7d8effc508c7e714d4fed82fb95e2b70 Mon Sep 17 00:00:00 2001
From: AUTERNAUD Alex <alex.auternaud@inria.fr>
Date: Thu, 29 Feb 2024 15:30:14 +0100
Subject: [PATCH] local_path_planner_mpc beginning

---
 src/robot_behavior/CMakeLists.txt             |   1 +
 .../behavior_local_path_planner_mpc_main.py   |  14 +
 .../src/robot_behavior/__init__.py            |   2 +
 .../behavior_local_path_planner_mpc.py        | 392 ++++++++++++++++++
 .../behavior_local_path_planner_mpc_node.py   | 229 ++++++++++
 .../src/robot_behavior/utils.py               |  17 +
 6 files changed, 655 insertions(+)
 create mode 100755 src/robot_behavior/scripts/behavior_local_path_planner_mpc_main.py
 create mode 100755 src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc.py
 create mode 100755 src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc_node.py

diff --git a/src/robot_behavior/CMakeLists.txt b/src/robot_behavior/CMakeLists.txt
index 066b5a7..2da02ae 100644
--- a/src/robot_behavior/CMakeLists.txt
+++ b/src/robot_behavior/CMakeLists.txt
@@ -184,6 +184,7 @@ catkin_install_python(PROGRAMS
 				 scripts/behavior_look_at_position_action_server_main.py
          scripts/behavior_global_path_planner_main.py
          scripts/behavior_goal_finder_main.py
+         scripts/behavior_local_path_planner_mpc_main.py
                       DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
 )
 
diff --git a/src/robot_behavior/scripts/behavior_local_path_planner_mpc_main.py b/src/robot_behavior/scripts/behavior_local_path_planner_mpc_main.py
new file mode 100755
index 0000000..15b6d72
--- /dev/null
+++ b/src/robot_behavior/scripts/behavior_local_path_planner_mpc_main.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+import rospy
+import pkg_resources
+import yaml
+from robot_behavior import LocalPathPlannerMPCNode
+
+
+if __name__ == '__main__':
+    # Init node
+    rospy.init_node('LocalPathPlannerMPCNode', log_level=rospy.DEBUG, anonymous=True)
+    controller = LocalPathPlannerMPCNode()
+    controller.run()
+    controller.shutdown()
+    # rospy.spin()
diff --git a/src/robot_behavior/src/robot_behavior/__init__.py b/src/robot_behavior/src/robot_behavior/__init__.py
index d8aa620..d8601e7 100755
--- a/src/robot_behavior/src/robot_behavior/__init__.py
+++ b/src/robot_behavior/src/robot_behavior/__init__.py
@@ -1,7 +1,9 @@
 from .behavior_global_path_planner import GlobalPathPlanner
 from .behavior_goal_finder import GoalFinder
+from .behavior_local_path_planner_mpc import LocalPathPlannerMPC
 from .behavior_generator_node import BehaviorGenerator
 from .behavior_global_path_planner_node import GlobalPathPlannerNode
+from .behavior_local_path_planner_mpc_node import LocalPathPlannerMPCNode
 from .behavior_goal_finder_node import GoalFinderNode
 from .behavior_go_to_body_action_client import GoToBodyActionClient
 from .behavior_go_to_body_action_server import GoToBodyActionServer
diff --git a/src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc.py b/src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc.py
new file mode 100755
index 0000000..bfcb648
--- /dev/null
+++ b/src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc.py
@@ -0,0 +1,392 @@
+# This file, part of Social MPC in the WP6 of the Spring project,
+# is part of a project that has received funding from the 
+# European Union’s Horizon 2020 research and innovation programme
+#  under grant agreement No 871245.
+# 
+# Copyright (C) 2020-2022 by Inria
+# Authors : Alex Auternaud, Timothée Wintz
+# alex.auternaud@inria.fr
+# timothee.wintz@inria.fr
+
+import time
+import scipy
+from scipy.optimize import NonlinearConstraint, Bounds, minimize
+from jax.config import config
+from jax.ops import index, index_update
+from jax.lax import cond, reshape, bitwise_and
+import jax.numpy as np
+import jax
+from jax import grad, jit, vmap, jacfwd, custom_jvp, partial
+import rospy
+from robot_behavior.utils import local_to_global_jax
+
+config.update("jax_enable_x64", True)
+# config.update('jax_disable_jit', True)
+
+# @partial(custom_jvp, nondiff_argnums=[0])
+def interp1d(cost_map, angles):
+    return jax.scipy.ndimage.map_coordinates(cost_map, angles, order=1, mode='wrap')
+
+def rotmat(theta):
+    return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+
+def rotmat_inv(theta):
+    return np.array([[np.cos(theta), np.sin(theta)], [-np.sin(theta), np.cos(theta)]])
+
+# @interp1d.defjvp
+# def interp1d_jvp(cost_map, primals, tangents):
+#     ddepth = np.gradient(cost_map)
+#     primals_out = interp1d(cost_map, primals)
+#     tangents_out = interp1d(ddepth, primals) * tangents[0]
+#     return primals_out, tangents_out
+
+
+class LocalPathPlannerMPC:
+    def __init__(self, h, robot_config, horizon, dim_config, u_lb, u_ub,
+                 joints_lb=None, joints_ub=None,
+                 max_acceleration=None,
+                 wall_avoidance_points=None,
+                 max_iter=100,
+                 cost_map_region=None):
+        """
+            MPC Constructor.
+
+            Args:
+                h (float): time step
+                robot (mpc.Robot): robot model
+                u_lb (array): 1D array of length robot.n_joints + 2 lower bound for control
+                u_ub (array): 1D array of length robot.n_joints + 2 upper bound for control
+                joints_lb (array, optional): 1D array of length robot.n_angles lower bound for joint values
+                joints_ub (array, optional): 1D array of length robot.n_angles upper bound for joint values
+                reg_parameter (array or float): 1D array of size robot.n_joints + 2 or float regularization parameter
+        """
+        self.h = h
+        self.robot = robot_config
+        self.dim_config = dim_config
+
+        self.pan_target_dim = self.dim_config['pan_target_dim']
+        self.goto_target_dim = self.dim_config['goto_target_dim']
+        self.human_target_dim = self.dim_config['human_target_dim']
+        self.cost_map_dim = self.dim_config['cost_map_dim']
+        self.weights_dim = self.dim_config['weights_dim']
+        self.loss_coef_dim = self.dim_config['loss_coef_dim']
+        self.loss_rad_dim = self.dim_config['loss_rad_dim']
+
+        n_joints = 0
+        if self.robot.has_pan:
+            n_joints += 1
+        if self.robot.has_tilt:
+            n_joints += 1
+
+        self.action_dim = n_joints + 2
+        self.horizon = horizon
+        self.actions_shape = (self.horizon, self.action_dim)
+        self.n_angles = n_joints
+
+        self.fw_angles = jit(self._fw_angles)
+        self.fw_angles_flat = jit(self._fw_angles_flat)
+        self.g_fw_angles_flat = jit(jacfwd(self._fw_angles_flat, argnums=1))
+
+        self.fw_base_loss_flat = jit(self._fw_base_loss_flat)
+        self.g_fw_base_loss_flat = jit(grad(self._fw_base_loss_flat, 0))
+
+        self.fw_base_loss = jit(self._fw_base_loss)
+        self.fw_base_positions = jit(self._fw_base_positions)
+
+        self.l2_loss = jit(self._l2_loss)
+        self.exp_loss = jit(self._exp_loss)
+        self.maxout_loss = jit(self._maxout_loss)
+        self.minmaxout_loss = jit(self._minmaxout_loss)
+
+        self.regularization_term = jit(self._regularization_term)
+
+        self.wall_avoidance_loss = jit(self._wall_avoidance_loss)
+
+        self.u_lb = u_lb
+        self.u_ub = u_ub
+
+        self.joints_lb = joints_lb
+        self.joints_ub = joints_ub
+
+        self.max_iter = max_iter
+        self.max_acceleration = max_acceleration
+        self.wall_avoidance_points = wall_avoidance_points
+
+        if cost_map_region is not None:
+            self.cost_map_region = cost_map_region
+        else:
+            self.cost_map_region = [[-3, 3], [-3,3]]
+
+        self.fw_base_positions_ang_c = None
+        self.fw_base_positions_pos_c = None
+        self.fw_b_l_f = None
+        self.count = 0
+        self.fw_loss = 0.
+
+
+    def step(self, state, weights, actions, loss_coef, loss_rad, cost_map=None, reg_parameter=None, goto_goal=None, pan_goal=None, human_features=None):
+        """
+            MPC Step.
+
+            Args:
+                state (array): 1D array of the form [alpha_1, alpha_2, x_1, y_1, x_2, y_2] of size n_angles + n_features x n_dim
+                # horizon (int): number of future steps to consider
+                actions (array): (optional) intial value for actions. 2D array of shape n_horizon x (n_angles + 2).
+                    Action order is joint angle velocity, base angle velocity, base linear velocity
+            Returns:
+                array: 2D array of shape n_horizon x (n_angles + 2), optimal actions
+        """
+        assert(len(state) == self.n_angles)
+        assert(actions.shape == self.actions_shape)
+        assert(len(weights) == self.weights_dim)
+        assert(len(loss_coef) == self.loss_coef_dim)
+        assert(len(loss_rad) == self.loss_rad_dim)
+
+        if cost_map is not None:
+            assert(cost_map.shape == self.cost_map_dim)
+        if reg_parameter is not None:
+            assert(len(reg_parameter) == self.action_dim or np.isscalar(reg_parameter))
+        if goto_goal is not None:
+            assert(len(goto_goal) == self.goto_target_dim)
+        if pan_goal is not None:
+            assert(len(pan_goal) == self.pan_target_dim)
+        if human_features is not None:
+            assert(len(human_features) == self.human_target_dim)
+
+        if reg_parameter is None:
+            reg_parameter = 0.
+        if pan_goal is None:
+            pan_goal = np.zeros(self.pan_target_dim)
+        if goto_goal is None:
+            goto_goal = np.zeros(self.goto_target_dim)
+        if human_features is None:
+            human_features = np.zeros(self.human_target_dim)
+        if cost_map is None:
+            cost_map = np.zeros(self.cost_map_dim)
+
+
+        rospy.loginfo('goals : {0}, {2}, {1}'.format(goto_goal, pan_goal, human_features))
+
+        fw_loss = lambda x : self.fw_base_loss_flat(x, state=state, cost_map=cost_map, weights=weights, reg_parameter=reg_parameter, loss_coef=loss_coef, loss_rad=loss_rad, goto_goal=goto_goal, pan_goal=pan_goal, human_features=human_features)
+
+        g_fw_loss = lambda x : self.g_fw_base_loss_flat(x, state=state, cost_map=cost_map, weights=weights, reg_parameter=reg_parameter, loss_coef=loss_coef, loss_rad=loss_rad, goto_goal=goto_goal, pan_goal=pan_goal, human_features=human_features)
+
+        lb = np.zeros(self.actions_shape)
+        lb = index_update(lb, index[:, :], self.u_lb[np.newaxis, :])
+        if self.count == 1:
+            lb = index_update(lb, index[0, :], actions[0, :])
+        if self.count > 1:
+            lb = index_update(lb, index[0, :], actions[1, :])
+        lb = lb.flatten()
+        ub = np.zeros(self.actions_shape)
+        ub = index_update(ub, index[:, :], self.u_ub[np.newaxis, :])
+        if self.count == 1:
+            ub = index_update(ub, index[0, :], actions[0, :])
+        if self.count > 1:
+            ub = index_update(ub, index[0, :], actions[1, :])
+        ub = ub.flatten()
+        bounds = Bounds(lb, ub)
+        if self.joints_lb is not None or self.joints_ub is not None:
+            # rospy.logdebug("using constraints")
+            lb_angles = np.zeros((self.horizon, self.n_angles))
+            lb_angles = index_update(lb_angles, index[:, :], self.joints_lb[np.newaxis, :])
+            ub_angles = np.zeros((self.horizon, self.n_angles))
+            ub_angles = index_update(ub_angles, index[:, :], self.joints_ub[np.newaxis, :])
+            constraints = NonlinearConstraint(lambda x: self.fw_angles_flat(state, x),
+                                              jac=lambda x: np.array(self.g_fw_angles_flat(state, x)),
+                                              lb=lb_angles.flatten(), ub=ub_angles.flatten())
+        else:
+            constraints = []
+        options = {'maxiter': self.max_iter}
+        r = minimize(fw_loss, jac=g_fw_loss, x0=actions.flatten(), bounds=bounds, constraints=constraints, options=options)
+        # rospy.logdebug(r)
+
+        if not r.success and options['maxiter'] > 20:
+            rospy.logwarn("Warning: failed optimization")
+            rospy.logwarn(r)
+
+
+        # self.fw_base_positions_ang_c, self.fw_base_positions_pos_c = self.fw_base_positions(reshape(r.x, (self.horizon, self.n_angles + 2)), 0.)
+        # self.fw_b_l_f = self._fw_base_loss_flat(r.x, state=state, cost_map=cost_map, weights=weights, reg_parameter=reg_parameter, loss_coef=loss_coef, loss_rad=loss_rad, goto_goal=goto_goal, pan_goal=pan_goal, human_features=human_features)
+        self.fw_loss = r.fun
+        # rospy.logdebug("fw loss : {}".format(self.fw_loss))
+        # rospy.logdebug(r.nit, r.nfev, r.njev)
+        # rospy.logdebug(r)
+        self.count += 1
+
+        return reshape(r.x, (self.horizon, self.n_angles + 2))
+
+
+    def _l2_loss(self, state, target):
+        return np.sum((state - target)**2)
+
+    def _minmaxout_loss(self, state, target, acceptance_r, coef):
+        return np.minimum(np.maximum(coef*(np.sum((state - target)**2) - acceptance_r**2), 0.), 1.)
+
+    def _maxout_loss(self, state, target, acceptance_r, coef):
+        return np.maximum(coef*(np.sum((state - target)**2) - acceptance_r**2), 0.)
+
+    def _exp_loss(self, dist):
+        return np.minimum(np.exp(11*(dist - 0.5)), 1.)
+
+    def _fw_angles(self, state, actions):
+        return state[np.newaxis, :self.n_angles] + self.h * np.cumsum(actions[:, :self.n_angles], axis=0)
+
+    def _fw_angles_flat(self, state, x):
+        actions = reshape(x, self.actions_shape)
+        return self.fw_angles(state, actions).flatten()
+
+    def _fw_base_positions(self, actions, start_ang):
+        base_angles = self.h * np.cumsum(actions[:, self.n_angles]) + start_ang
+        velocities = np.stack([actions[:, -1] * np.cos(base_angles), actions[:, -1] * np.sin(base_angles)], axis=1)
+        base_positions = self.h * np.cumsum(velocities, axis=0)
+        return base_angles, base_positions
+
+    def false_fun_costmap(self, positions, cost_map):
+        cost_map_xmin = self.cost_map_region[0][0]
+        cost_map_xmax = self.cost_map_region[0][1]
+        cost_map_ymin = self.cost_map_region[1][0]
+        cost_map_ymax = self.cost_map_region[1][1]
+        cost_map_xres, cost_map_yres = cost_map.shape
+        im_x = (positions[:,0] - cost_map_xmin)/(cost_map_xmax - cost_map_xmin) * cost_map_xres
+        im_y = (cost_map_ymax - positions[:,1])/(cost_map_ymax - cost_map_ymin) * cost_map_yres
+        int_position = np.stack([im_x, im_y], axis=1)
+        costs = jax.scipy.ndimage.map_coordinates(cost_map, np.transpose(int_position), order=1, mode='constant')
+        return np.sum(costs)
+
+    def true_fun_l2_loss(self, x, goal):
+        loss = vmap(self.l2_loss, in_axes=(0, None))(x, goal)
+        return np.sum(loss)
+
+    def true_fun_maxout_loss(self, x, goal, r, coef):
+        loss = vmap(self.maxout_loss, in_axes=(0, 0, None, None))(x, goal, r, coef)
+        return np.sum(loss)
+
+    def true_fun_maxout_loss_static(self, x, goal, r, coef):
+        loss = vmap(self.maxout_loss, in_axes=(0, None, None, None))(x, goal, r, coef)
+        return np.sum(loss)
+
+    def true_fun_minmaxout_loss(self, x, goal, r, coef):
+        loss = vmap(self.minmaxout_loss, in_axes=(0, None, None, None))(x, goal, r, coef)
+        return np.sum(loss)
+
+    def fun_loss(self, goal_features, base_positions, base_angles, active_angle_loss, weight, loss_features):
+        x, y, ang, lin_v, ang_v = goal_features
+        velocities = np.array([0., ang_v, lin_v])
+        actions = np.tile(velocities, (self.horizon, 1))
+        goal_angles, goal_positions = self.fw_base_positions(actions, ang)
+        goal_positions += np.array([x, y])
+        loss = self.true_fun_maxout_loss(base_positions[:, 1], goal_positions[:, 1], weight[0, 1], weight[1, 1])
+        loss += cond(loss_features == 1,
+                    lambda _: self.true_fun_maxout_loss(np.linalg.norm(base_positions - goal_positions, axis=1), np.zeros(self.horizon), weight[0, 0], weight[1, 0]),
+                    lambda _: self.true_fun_maxout_loss(base_positions[:, 0], goal_positions[:, 0], weight[0, 0], weight[1, 0]),
+                    operand=None)  # constraint on x or just on dist
+
+        loss += cond(active_angle_loss,  # bitwise_and(lin_v == 0., ang_v == 0.)
+                     lambda _: self.true_fun_maxout_loss(base_angles, goal_angles, weight[0, 2], weight[1, 2]),
+                     lambda _: 0.,
+                     operand=None)
+        # loss += self.true_fun_maxout_loss(base_angles, goal_angles, weight[0, 2], weight[1, 2])
+        # loss = self.true_fun_l2_loss(base_positions[:, 0], goal_positions[:, 0])*weight[1, 0]
+        # loss += self.true_fun_l2_loss(base_positions[:, 1], goal_positions[:, 1])*weight[1, 1]
+        # loss += cond(bitwise_and(active_angle_loss, bitwise_and(lin_v == 0., ang_v == 0.)),
+        #              lambda _: self.true_fun_l2_loss(base_angles, goal_angles)*weight[1, 2],
+        #              lambda _: 0.,
+        #              operand=None)
+        return loss
+
+    def fun_pan_loss(self, state, actions, goal_features, weight):
+        x, y, lin_v, ang_v = goal_features
+        velocities = np.array([0., ang_v, lin_v])
+        actions_features = np.tile(velocities, (self.horizon, 1))
+        _, goal_positions = self.fw_base_positions(actions_features, 0.)
+        goal_positions += np.array([x, y])
+        pan_goal_fw = np.arctan2(goal_positions[:, 1], goal_positions[:, 0])
+        loss = self.true_fun_maxout_loss(self.fw_angles(state, actions), pan_goal_fw, weight[0], weight[1])
+        # loss = self.true_fun_l2_loss(self.fw_angles(state, actions), pan_goal_fw)*weight[1]
+        return loss
+
+    def _regularization_term(self, actions, reg_parameter):
+        if np.isscalar(reg_parameter):
+            reg_term = np.sum(reg_parameter * actions ** 2)
+        else:
+            reg_term = np.sum(reg_parameter[np.newaxis, :] * actions ** 2)
+        return reg_term
+
+    def _wall_avoidance_loss(self, positions, object_map):
+        wall_avoidance_loss = cond(np.any(object_map),
+                             lambda _: np.sum(vmap(self.false_fun_costmap, in_axes=(0, None))(positions, object_map)),
+                             lambda _: 0.,
+                             operand=None)
+        return wall_avoidance_loss
+
+
+    def _fw_base_loss(self, actions, state, cost_map, weights, reg_parameter, loss_coef, loss_rad, goto_goal, pan_goal, human_features):
+        object_map = cost_map[:, :, 0]
+        social_map = cost_map[:, :, 1]
+        base_angles, base_positions = self.fw_base_positions(actions, 0.)
+
+        # escorted human loss (x, y, ang)
+        weight = np.array([loss_rad[3:-1], loss_coef[3:-1]])
+        # rospy.logdebug('human_features weight : {}'.format(weight))
+        human_features_loss = cond(human_features[-1] > 0.,
+                                   lambda _: self.fun_loss(goal_features=human_features[:5],
+                                                           base_positions=base_positions,
+                                                           base_angles=base_angles,
+                                                           active_angle_loss=human_features[-2],
+                                                           weight=weight,
+                                                           loss_features=human_features[-3]),
+                                   lambda _: 0.,
+                                   operand=None)
+
+        # pan loss (ang)
+        weight = np.array([loss_rad[-1], loss_coef[-1]])
+        # rospy.logdebug('pan weight : {}'.format(weight))
+        pan_angle_loss = cond(pan_goal[-1] > 0.,
+                               lambda _: self.fun_pan_loss(state=state,
+                                                           actions=actions,
+                                                           goal_features=pan_goal[:-1],
+                                                           weight=weight),
+                               lambda _: 0.,
+                               operand=None)
+
+        # base goal loss (x, y, ang)
+        weight = np.array([loss_rad[:3], loss_coef[:3]])
+        # rospy.logdebug('goto weight : {}'.format(weight))
+        goto_goal_loss = cond(goto_goal[-1] > 0.,
+                                   lambda _: self.fun_loss(goal_features=goto_goal[:5],
+                                                           base_positions=base_positions,
+                                                           base_angles=base_angles,
+                                                           active_angle_loss=goto_goal[-2],
+                                                           weight=weight,
+                                                           loss_features=goto_goal[-3]),
+                                   lambda _: 0.,
+                                   operand=None)
+
+        # cost map loss
+        # cost_map = np.maximum(object_map, social_map)
+        cost_map = weights[2] * object_map + weights[3] * social_map
+        cost_map_loss = cond(np.any(cost_map),
+                             lambda _: self.false_fun_costmap(base_positions, cost_map),
+                             lambda _: 0.,
+                             operand=None)
+
+        fw_base_pose = np.stack((base_positions[:, 0], base_positions[:, 1], base_angles), axis=-1)
+        new_positions = vmap(local_to_global_jax, in_axes=(None, 0))(fw_base_pose, self.wall_avoidance_points)
+        wall_avoidance_loss = self.wall_avoidance_loss(new_positions, object_map)
+
+        reg_term = self.regularization_term(actions, reg_parameter)
+
+        # rospy.logdebug('wall_avoidance_loss : {}'.format(wall_avoidance_loss))
+        # rospy.logdebug('cost_map_loss : {}'.format(cost_map_loss))
+        # rospy.logdebug('goto_goal_loss : {}'.format(goto_goal_loss))
+        # rospy.logdebug('human_features_loss : {}'.format(human_features_loss))
+        # rospy.logdebug('regularization_term : {}'.format(reg_term))
+        # rospy.logdebug('pan_angle_loss : {}'.format(pan_angle_loss))
+
+        loss  = weights[0]*goto_goal_loss + weights[4]*cost_map_loss + wall_avoidance_loss*weights[5] + human_features_loss*weights[1] + reg_term + pan_angle_loss
+        return loss
+
+    def _fw_base_loss_flat(self, x, state, cost_map, weights, reg_parameter, loss_coef, loss_rad, goto_goal, pan_goal, human_features):
+        actions = reshape(x, self.actions_shape)
+        return self.fw_base_loss(actions, state, cost_map, weights, reg_parameter, loss_coef, loss_rad, goto_goal, pan_goal, human_features)
diff --git a/src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc_node.py b/src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc_node.py
new file mode 100755
index 0000000..d63276e
--- /dev/null
+++ b/src/robot_behavior/src/robot_behavior/behavior_local_path_planner_mpc_node.py
@@ -0,0 +1,229 @@
+#!/usr/bin/env python3
+import numpy as np
+# from collections import namedtuple
+import pkg_resources
+import tf
+import sys
+import os
+import yaml
+from multiprocessing import Lock
+import numpy as np
+import rospy
+from nav_msgs.msg import OccupancyGrid
+from robot_behavior.behavior_local_path_planner_mpc import LocalPathPlannerMPC
+from robot_behavior.utils import constraint_angle, local_to_global
+from robot_behavior.behavior_generator_node import RobotState
+from social_mpc.config.config import ControllerConfig, RobotConfig
+import time
+from robot_behavior.ros_4_hri_interface import ROS4HRIInterface 
+from spring_msgs.msg import GoToPosition, GoToEntity, LookAtEntity, LookAtPosition
+
+
+class LocalPathPlannerMPCNode:
+    def __init__(self):
+        os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
+        mpc_config_name = rospy.get_param('~mpc_config_name', 'None')
+        robot_config_name = rospy.get_param('~robot_config_name', 'None')
+        mpc_config_path = None if mpc_config_name == 'None' else mpc_config_name
+        robot_config_path = None if robot_config_name == 'None' else robot_config_name
+        self.mpc_config_path = mpc_config_path
+        self.robot_config_path = robot_config_path
+        # self.namespace_slam = rospy.get_param('~namespace_slam', '/slam')
+        self.map_frame = rospy.get_param('~map_frame', 'map')
+        self.robot_frame = rospy.get_param('~robot_frame', 'base_footprint')
+        self.global_occupancy_map_topic = rospy.get_param('~global_occupancy_map_topic', '/slam/global_occupancy_map')
+        self.local_occupancy_map_topic = rospy.get_param('~local_occupancy_map_topic', '/slam/local_occupancy_map')
+        self.max_humans_world = rospy.get_param('max_humans_world', 20)
+        self.max_groups_world = rospy.get_param('max_groups_world', 10)
+        self.namespace = rospy.get_param('namespace', 'behavior_generator')
+        
+        self.read_robot_config(filename=self.robot_config_path)
+        self.read_config(filename=self.mpc_config_path)
+
+        self.local_map_data = None
+        self.rtabmap_ready = False
+        self.local_map_static = None
+
+        self.init_ros_subscriber_and_publicher()
+
+        self.tf_broadcaster = tf.TransformBroadcaster()  # Publish Transform to look for human
+        self.tf_listener = tf.TransformListener()  # Listen Transform to look for human
+
+        self._check_all_sensors_ready()
+
+        self.init_mpc()
+
+
+    def init_ros_subscriber_and_publicher(self):
+        r''' Initialize the subscribers and publishers '''
+        self._subscribers = []
+        self._publishers = []
+        self._action_server = []
+        self._timers = []
+
+        ### ROS Subscribers
+        self._subscribers.append(rospy.Subscriber(self.local_occupancy_map_topic, OccupancyGrid, callback=self._local_map_callback, queue_size=1))
+        # self._subscribers.append(rospy.Subscriber(self.namespace + '/action/go_to_position', GoToPosition, callback=self._go_to_position_callback, queue_size=1))
+        # self._subscribers.append(rospy.Subscriber(self.namespace + '/action/go_to_body', GoToEntity, callback=self._go_to_body_callback, queue_size=1))
+        # self._subscribers.append(rospy.Subscriber(self.namespace + '/action/go_to_person', GoToEntity, callback=self._go_to_person_callback, queue_size=1))
+        # self._subscribers.append(rospy.Subscriber(self.namespace + '/action/go_to_group', GoToEntity, callback=self._go_to_group_callback, queue_size=1))
+
+        ### ROS Publishers
+        # self._cmd_vel_pub = rospy.Publisher('/cmd_vel', Twist, queue_size=1)
+        # self._publishers.append(self._cmd_vel_pub)
+
+        rospy.loginfo("Initializing the LocalPathPlannerMPCNode")
+
+        # self._check_publishers_connection()
+
+
+    def run(self):
+        r''' Runs the ros wrapper '''
+        init_t = rospy.Time.now()
+        while not rospy.is_shutdown():
+            self.step(init_t)
+
+
+    def step(self, init_time):
+        r''' Step the ros wrapper'''
+        print('MPC')
+        rospy.sleep(2)
+
+
+
+
+    def init_mpc(self):
+        dim_config = {'goto_target_dim': self.controller_config.goto_target_dim,
+                      'human_target_dim': self.controller_config.human_target_dim,
+                      'pan_target_dim': self.controller_config.pan_target_dim,
+                      'cost_map_dim': self.local_static_map.shape[:2] + (2,),
+                      'weights_dim': len(self.controller_config.weights_description),
+                      'loss_coef_dim': self.controller_config.loss_coef.shape[1],
+                      'loss_rad_dim': self.controller_config.loss_rad.shape[1]}
+        fw_horizon = int(self.controller_config.fw_time/self.controller_config.h)
+        joints_lb = np.array([self.robot_config.min_pan_angle])
+        joints_ub = np.array([self.robot_config.max_pan_angle])
+        world_size = self.local_map_size
+        self.mpc = LocalPathPlannerMPC(
+            h=self.controller_config.h,
+            robot_config=self.robot_config,
+            dim_config=dim_config,
+            horizon=fw_horizon,
+            u_lb=np.array(self.controller_config.u_lb),
+            u_ub=np.array(self.controller_config.u_ub),
+            joints_lb=joints_lb,
+            joints_ub=joints_ub,
+            max_acceleration=np.array(
+                self.controller_config.max_acceleration),
+            wall_avoidance_points=self.controller_config.wall_avoidance_points,
+            max_iter=self.controller_config.max_iter_optim,
+            cost_map_region=world_size
+        )
+        self.initial_action = np.zeros(self.mpc.actions_shape)
+        # self.initial_action[:, -1] = 0.1 * config.u_ub[-1]
+        self.actions = np.copy(self.initial_action)
+
+
+    def read_config(self, filename=None):
+        if filename is None:
+            filename = pkg_resources.resource_filename(
+                'social_mpc', 'config/social_mpc.yaml')
+        elif os.path.isfile(filename):
+            self.passed_config_loaded = True
+        else:
+            filename = pkg_resources.resource_filename(
+                'social_mpc', 'config/social_mpc.yaml')
+        config = yaml.load(open(filename), Loader=yaml.FullLoader)
+        self.controller_config = ControllerConfig(config)
+
+        self.goal_finder_enabled = self.controller_config.goal_finder_enabled
+        self.path_planner_enabled = self.controller_config.path_planner_enabled
+        self.update_goals_enabled = self.controller_config.update_goals_enabled
+    
+
+    def read_robot_config(self, filename=None):
+        if filename is None:
+            filename = pkg_resources.resource_filename(
+                'sim2d', 'config/robot.yaml')
+            rospy.logdebug("No filename provided for the robot configuration, basic robot config loaded")
+        elif os.path.isfile(filename):
+            rospy.logdebug("Desired robot config loaded")
+        else:
+            filename = pkg_resources.resource_filename(
+                'sim2d', 'config/robot.yaml')
+            rospy.logdebug("Desired filename for the robot configuration does not exist, basic robot config loaded")
+        config = yaml.load(open(filename), Loader=yaml.FullLoader)
+        self.robot_config = RobotConfig(config)
+
+
+    def _local_map_callback(self, data):
+        self.local_map_data = data
+        self.x_local_map = self.local_map_data.info.origin.position.x
+        self.y_local_map = self.local_map_data.info.origin.position.y
+        self.local_map_width = self.local_map_data.info.width
+        self.local_map_height = self.local_map_data.info.height
+        self.local_map_resolution = self.local_map_data.info.resolution
+        self.local_map_size = [[self.x_local_map, self.x_local_map + self.local_map_width*self.local_map_resolution],[self.y_local_map, self.y_local_map + self.local_map_height*self.local_map_resolution]]
+        self.last_shape_local_map = (self.local_map_height, self.local_map_width)
+        self.local_static_map= (np.asarray(self.local_map_data.data) / 100).reshape(self.last_shape_local_map)
+
+
+    def _check_all_sensors_ready(self):
+        rospy.logdebug("START ALL SENSORS READY")
+        self._check_rtabmap_ready()
+        self._check_local_map_ready()
+        rospy.logdebug("ALL SENSORS READY")
+
+
+    def _check_rtabmap_ready(self):
+        rospy.logdebug("Waiting for rtabmap pose to be READY...")
+        while self.rtabmap_ready is None and not rospy.is_shutdown():
+            try:
+                self.tf_listener.waitForTransform(self.map_frame, self.robot_frame, rospy.Time(0), rospy.Duration(5.0))
+                self.rtabmap_ready = True
+                rospy.logdebug("Current rtabmap pose READY=>")
+
+            except:
+                rospy.logerr("Current rtabmap pose not ready yet, retrying for getting rtabmap pose")
+        return self.rtabmap_ready
+
+
+    def _check_local_map_ready(self):
+        self.local_map_data = None
+        rospy.logdebug("Waiting for {} to be READY...".format(self.local_occupancy_map_topic))
+        while self.local_map_data is None and not rospy.is_shutdown():
+            try:
+                self.local_map_data = rospy.wait_for_message(self.local_occupancy_map_topic, OccupancyGrid, timeout=5.0)
+                rospy.logdebug("Current {} READY=>".format(self.local_occupancy_map_topic))
+
+            except:
+                rospy.logerr("Current {} not ready yet, retrying for getting local map".format(self.local_occupancy_map_topic))
+        return self.local_map_data
+
+
+    def shutdown(self):
+        rospy.loginfo("Stopping the LocalPathPlannerMPCNode")
+        self.close()
+        rospy.loginfo("Killing the LocalPathPlannerMPCNode node")
+
+
+    def close(self):
+        if self._subscribers:
+            for subscriber in self._subscribers:
+                subscriber.unregister()
+
+        if self._publishers:
+            for publisher in self._publishers:
+                if isinstance(publisher, dict):
+                    for pub in publisher.values():
+                        pub.unregister()
+                else:
+                    publisher.unregister()
+        if self._timers:
+            for timer in self._timers:
+                timer.shutdown()
+
+        if self._action_server:
+            self._action_server.shutdown()
+        # self.ros_4_hri_interface.close()
+
diff --git a/src/robot_behavior/src/robot_behavior/utils.py b/src/robot_behavior/src/robot_behavior/utils.py
index ab088a8..55bf1db 100755
--- a/src/robot_behavior/src/robot_behavior/utils.py
+++ b/src/robot_behavior/src/robot_behavior/utils.py
@@ -2,6 +2,8 @@
 import numpy as np
 import tf
 from scipy.interpolate import RectBivariateSpline
+import jax.numpy as jnp
+from jax import grad, jit, vmap, jacfwd, custom_jvp, partial
 
 
 def constraint_angle(angle, min_value=-np.pi, max_value=np.pi):
@@ -21,11 +23,26 @@ def constraint_angle(angle, min_value=-np.pi, max_value=np.pi):
     return new_angle
 
 
+def local_to_global_jax(robot_position, x):
+    angle = rotmat_2d_jax(robot_position[:, -1])
+    y = vmap(vmpa_dot_jax, in_axes=(0, None))(jnp.moveaxis(angle, -1, 0), x)
+    return jnp.array(y + robot_position[:, :2])
+
+
+def vmpa_dot_jax(a, b):
+    return a.dot(b)
+
+
 def rotmat_2d(angle):
     return np.matrix([[np.cos(angle), -np.sin(angle)],
                       [np.sin(angle), np.cos(angle)]])
 
 
+def rotmat_2d_jax(angle):
+    return jnp.array([[jnp.cos(angle), -jnp.sin(angle)],
+                      [jnp.sin(angle), jnp.cos(angle)]])
+
+
 def local_to_global(robot_position, x):
     angle = robot_position[-1]
     y = rotmat_2d(angle).dot(x) + robot_position[:2]
-- 
GitLab