diff --git a/rl_audio_nav/rl/bin/run_agent_in_env.py b/rl_audio_nav/rl/bin/run_agent_in_env.py index 5c446c664a8bfe9fc94e9689f6550cee93eaf652..a1b7bc55fe05ce46de9c8a5e30fe6f44e389fae6 100644 --- a/rl_audio_nav/rl/bin/run_agent_in_env.py +++ b/rl_audio_nav/rl/bin/run_agent_in_env.py @@ -1,23 +1,26 @@ import logging import os import sys -from typing import Any, Optional, get_args +from typing import Optional -import exputils as eu -import gymnasium as gym +import numpy as np import torch -from torch import Tensor -from torch.distributions import Categorical from rl_audio_nav.rl.agents import Agent, get_action_from_policy +from rl_audio_nav.rl.environments.config_lib import CONFIG_LIB +from rl_audio_nav.rl.environments.map_envs.wer_map_env import RlanWerMapEnv +from rl_audio_nav.rl.environments.rlan_env import DiscreteAction +from rl_audio_nav.rl.utils import load_agent_class -# from rl_audio_nav.rl.agents.env_agents.rlan.ssl_agent import RlanSSLAgent -# from rl_audio_nav.rl.agents.env_agents.rlan.cnn_agent import RlanCnnAgent -from rl_audio_nav.rl.agents.env_agents.rlan.ms_ssl_backbone_agent import RlanMultiSourceSSLAgent -from rl_audio_nav.rl.utils import load_experiment +# from rl_audio_nav.rl.utils import load_rlan_experiment +from rl_audio_nav.utils import get_eu_experiment_path -ObsType = Any -ActType = Any +# from rl_audio_nav.rl.agents.env_agents.rlan.cnn_agent import RlanCnnAgent +# from rl_audio_nav.rl.agents.env_agents.rlan.ssl_agent import RlanSSLAgent +# from rl_audio_nav.rl.agents.env_agents.rlan.ms_ssl_backbone_agent import RlanMultiSourceSSLAgent +# from rl_audio_nav.supervised_localization.single_source.lightning_module import ( +# SSLModuleSingleSource, +# ) @torch.no_grad() @@ -32,76 +35,93 @@ def main( # env_name = 'CartPole-v1' exp_dir_path: str = "" - config: eu.AttrDict + # config: eu.AttrDict # Take the last experiment if not exp_dir_name: exp_dir_name = sorted(os.listdir("output/rl"))[-1] - exp_dir_path, config = load_experiment( - exp_dir_name=exp_dir_name, - ) + print(exp_dir_name) + # exp_dir_path, config = load_experiment( + # exp_dir_name=exp_dir_name, + # ) + env_name = "rlan-wermap-v0" + config = CONFIG_LIB[env_name] + env_config = config.env_config + env: RlanWerMapEnv = RlanWerMapEnv(config=env_config) + # env, agent = load_rlan_experiment(exp_dir_name) + + # env.unwrapped.render_mode = "human" + # else: # config = CONFIG_LIB[env_name] # config.env_name = env_name - # agent: Agent = load_agent_class( - # class_name=config.agent.class_name - # ).from_env( - # env_name=config.env_name, - # env_config=config.get('env_config', None) - # ) + agent: Agent = load_agent_class( + class_name=config.agent.class_name, + ).from_env( + env_name=env_name, + env_config=env_config, + ) # # if exp_dir_path: # agent.load_model_checkpoint(exp_folder_path=exp_dir_path) # agent: Agent = RlanSSLAgent.from_env( + # agent: Agent = RlanMultiSourceSSLAgent.from_env( # agent: Agent = RlanCnnAgent.from_env( - agent: Agent = RlanMultiSourceSSLAgent.from_env( - env_name="rlan-wermap-v0", - env_config=config.env_config, - ) + # env_name="rlan-wermap-v0", + # env_config=config.env_config, + # ) - env_kwargs: dict[str, Any] = {} - if "env_config" in config and config.env_config is not None: - env_kwargs["config"] = config.env_config - - # env: WerMapAbstractEnv = gym.make(config.env_name, **env_kwargs) # type: ignore - # assert isinstance(env.unwrapped, WerMapAbstractEnv) - logger.info("Initializing the environment") - env: gym.Env = gym.make( - id=config.env_name, - render_mode="human", - **env_kwargs, - ) + # ssl_network: SSLModuleSingleSource = SSLModuleSingleSource.load_from_checkpoint( + # checkpoint_path="/exputils/012_ssl_single_source/experiments/experiment_000013/repetition_000000/final_model.ckpt", + # input_shape=(6, 257, 64), + # batch_size=250, + # norm_type="batch", # type: ignore + # predict_dist=False, + # ) + + # env_kwargs: dict[str, Any] = {} + # if "env_config" in config and config.env_config is not None: + # env_kwargs["config"] = config.env_config + + # # env: WerMapAbstractEnv = gym.make(config.env_name, **env_kwargs) # type: ignore + # # assert isinstance(env.unwrapped, WerMapAbstractEnv) + # logger.info("Initializing the environment") + # env: gym.Env = gym.make( + # id=config.env_name, + # render_mode="human", + # **env_kwargs, + # ) logger.info("Env class: %s", env.unwrapped.__class__.__name__) - EnvObsType, EnvActType = get_args(env.unwrapped.__orig_bases__[0]) - logger.info("ObsType: %s", EnvObsType) - logger.info("ActType: %s", EnvActType) - obs: ObsType = env.reset(seed=config.seed)[0] - step: int = 0 + obs = env.reset( + # seed=config.seed, + )[0] + global_step_counter: int = 0 while True: # logger.info("obs = %s", obs) + if env.unwrapped.render_mode: + env.render() - policy: Categorical - value: Tensor with torch.no_grad(): + print("GT:", env.unwrapped.simulator.get_doa() * 180 / np.pi) # type ignore policy, value = agent( torch.tensor(obs).to(torch.float32), ) + # ssl_network.predict_step(batch=obs.unsqueeze(0)) - logger.info("step = %i", step) + logger.info("step = %i", global_step_counter) - print(policy.entropy()) - action: ActType = int( + # print(policy.entropy()) + action: int + action = int( get_action_from_policy( policy=policy, evaluate=True, ), ) - logger.info("a = %i", action) - - # logger.info("action = %s", ACTION_NAMES[action]) - # action: ActType = env.action_space.sample() + action = env.unwrapped.action_space.sample() + logger.info("action = %s", DiscreteAction(action).name) # mic_x, mic_y = env.agent_mic.location[:2] # delta_x, delta_y = ACTIONS[action] * delta @@ -126,13 +146,23 @@ def main( obs, reward, terminated, truncated, info = env.step(action) logger.info("r = %f", reward) + logger.info("orientation = %s", info["agent_orientation"]) if terminated or truncated: - break + print("RESETTING") + obs, _ = env.reset() - step += 1 + global_step_counter += 1 if __name__ == "__main__": - _exp_dir_name: Optional[str] = sys.argv[1] if len(sys.argv) == 2 else None - main(exp_dir_name=_exp_dir_name) + exp_id: int = int(sys.argv[1]) + rep_id: int = int(sys.argv[2]) if len(sys.argv) == 3 else 0 + exp_dir: str = get_eu_experiment_path( + exp_id=exp_id, + rep_id=rep_id, + campaign_name="005_rlan_training", + # campaign_name="004_rlan_wermaps", + ) + # exp_dir = "" + main(exp_dir_name=exp_dir)