Mentions légales du service

Skip to content
Snippets Groups Projects
Commit 6720ac56 authored by LEPAGE Gaetan's avatar LEPAGE Gaetan
Browse files

Environments refactoring

parent 365adb42
Branches
No related tags found
No related merge requests found
# Reinforcement Learning Audio Navigation
## Audio simulator
## Reinforcement Learning
### Environments
- RlanAbstractEnv
- RlanOnlineDecodingAbstractEnv
- RlanOnlineDecodingEnv
- RlanOnlineDecodingVoskVecEnv
- RlanOnlineDecodingSpeechbrainVecEnv
- RlanWerMapEnv
import os
import logging
from abc import ABC, abstractmethod
from typing import Sequence, Union
from abc import ABC
from typing import Sequence
import numpy as np
from torch.distributions import Distribution, Categorical
......
......@@ -10,11 +10,9 @@ This experiment trains Proximal Policy Optimization (PPO) agent Atari Breakout
It runs the [game environments on multiple processes](../game.html) to sample efficiently.
"""
import logging
from typing import Sequence, Union
from typing import Sequence
import exputils as eu
import numpy as np
import torch
from torch import nn
......@@ -22,7 +20,7 @@ from torch import Tensor
from torch.distributions import Categorical, Distribution
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence
from ..environments.rl_audio_nav_env import ACTION_NAMES
from ..environments.abstract_env import ACTION_NAMES
from .agent import Agent, get_action_from_policy
......@@ -41,10 +39,9 @@ class LSTMAgent(Agent):
def __init__(self,
config: eu.AttrDict,
device: Union[str, torch.device] = torch.device('cpu'),
**kwargs) -> None:
super().__init__(device=device)
super().__init__()
self.config: eu.AttrDict = eu.combine_dicts(kwargs,
config,
......
from .rl_audio_nav_env import RlAudioNavEnv
from gym.envs.registration import register
# from .rlan_vec_env_vosk import RlAudioNavVecEnv
# Gym environment registration table: (environment id, entry point).
_ENV_SPECS: list[tuple[str, str]] = [
    ('debug-v0', 'rl_audio_nav.rl.environments.debug_env:DebugEnv'),
    ('rlan-v0', 'rl_audio_nav.rl.environments:RlAudioNavEnv'),
    ('rlan-vec-vosk-v0',
     'rl_audio_nav.rl.environments.rlan_vec_env_vosk:RlAudioNavVecEnv'),
    ('rlan-vec-speechbrain-v0',
     'rl_audio_nav.rl.environments.rlan_vec_env_speechbrain:RlAudioNavVecEnv'),
    ('rlan-wermap-env-v0',
     'rl_audio_nav.rl.environments.rlan_wer_map:RlanWerMapEnv'),
    ('rlan-wermap-sync-vec-env-v0',
     'rl_audio_nav.rl.environments.rlan_wer_map:RlanWerMapSyncVecEnv'),
    ('rlan-wermap-async-vec-env-v0',
     'rl_audio_nav.rl.environments.rlan_wer_map:RlanWerMapAsyncVecEnv'),
    ('rlan-wermap-2Dpos-v0',
     'rl_audio_nav.rl.environments.pos_input:PosInputWerMapEnv'),
]

# Register every environment with gym so it can be created via `gym.make(id)`.
for _env_id, _entry_point in _ENV_SPECS:
    register(id=_env_id, entry_point=_entry_point)
def _register_rlan_env(id: str,
                       filename: str,
                       classname: str) -> None:
    """
    Register a RLAN environment class with gym.

    Args:
        id (str): Gym environment id (e.g. 'rlan-wermap-v0').
        filename (str): Name of the module, inside
            `rl_audio_nav.rl.environments`, that contains the environment
            class.
        classname (str): Name of the environment class in that module.
    """
    # Build the gym entry point from the module name and class name.
    register(id=id,
             entry_point=f'rl_audio_nav.rl.environments.{filename}:{classname}')
# All RLAN environments, grouped by family:
#   gym id -> (module under `rl_audio_nav.rl.environments`, class name)
_RLAN_ENVS: dict[str, tuple[str, str]] = {
    # Misc
    'debug-v0': ('debug_env', 'DebugEnv'),
    # Online decoding environments
    'rlan-online-v0': ('online_decoding',
                       'RlAudioNavOnlineDecodingEnv'),
    'rlan-online-vosk-vec-v0': ('online_decoding_vosk_vec',
                                'RlanOnlineDecodingVoskVecEnv'),
    'rlan-online-speechbrain-vec-v0': ('online_decoding_speechbrain_vec',
                                       'RlanOnlineDecodingSpeechbrainVecEnv'),
    # WER maps
    'rlan-wermap-v0': ('wer_map', 'RlanWerMapEnv'),
    'rlan-wermap-vec-sync-v0': ('wer_map', 'RlanWerMapSyncVecEnv'),
    'rlan-wermap-vec-async-v0': ('wer_map', 'RlanWerMapAsyncVecEnv'),
    'rlan-wermap-2Dpos-v0': ('pos_input', 'PosInputWerMapEnv'),
}

# Register every environment (dict preserves insertion order, so the
# registration order is identical to the original explicit calls).
for _id, (_filename, _classname) in _RLAN_ENVS.items():
    _register_rlan_env(id=_id, filename=_filename, classname=_classname)
......@@ -2,21 +2,20 @@
RL Audio navigation environment
"""
from typing import Iterator, Optional
from typing import Iterator
import time
import logging
from pprint import pformat
from math import ceil
import numpy as np
import gym
from gym import spaces
import scipy
import scipy.signal
import exputils as eu
from numpy.random import default_rng # type: ignore
import matplotlib
from ...asr.speech_data_set import speech_data_set_iterator
from ...asr import AsrDecoder, compute_wer
from ...audio_simulator import utils
from ...audio_simulator.audio_signal import SpeechSignal
from ...audio_simulator.room import (Room,
......@@ -48,44 +47,12 @@ ACTION_NAMES: list['str'] = [
NUM_ACTIONS: int = len(ACTION_NAMES)
class StftSpace(gym.Space):
    """
    Observation space holding an STFT of variable length
    (i.e. variable sound sample duration).
    """

    def contains(self, x) -> bool:
        """Return True if `x` is a one-dimensional numpy array."""
        return isinstance(x, np.ndarray) and x.ndim == 1

    def sample(self) -> np.ndarray:
        """
        Warning: a random spectrogram makes no sense.
        It can maybe be useful for debugging purposes.

        Returns:
            np.ndarray: A random int16 signal (16 kHz sampling assumed).
        """
        generator: np.random.Generator = default_rng(self.seed()[0])

        # Sample a signal duration (in seconds).
        # 12.7s is the average duration of a LibriSpeech sample.
        # 3.57s is the standard deviation of the durations.
        length: float = generator.normal(loc=12.7, scale=3.57, size=1)[0]

        # `size` must be an integer number of samples: `length * 16000` is a
        # float (and may even be negative, the normal distribution has
        # unbounded support), which would make `integers` raise a TypeError.
        num_samples: int = max(0, int(length * 16000))

        signal: np.ndarray = generator.integers(low=np.iinfo(np.int16).min,  # type: ignore
                                                high=np.iinfo(np.int16).max,
                                                dtype=np.int16,
                                                size=num_samples)
        # TODO return the stft, not the signal
        # Actually, directly sample the STFT...
        return signal

    def __eq__(self, o) -> bool:
        return isinstance(o, StftSpace)
class RlAudioNavAbstractEnv:
class RlanAbstractEnv:
"""
An environment in which the agent has to move to "hear better" (i.e. to maximize the ASR).
"""
default_config: eu.AttrDict = eu.AttrDict(
# STFT window length
stft_window_length=256,
# Using the room default configuration.
room=eu.combine_dicts(eu.AttrDict(gpu=True),
......@@ -97,13 +64,6 @@ class RlAudioNavAbstractEnv:
# speech_source_position='random',
speech_source_position=[6, 3, 1.5],
# ASR backend
# asr='vosk',
asr='speechbrain',
# Use GPU asr
enable_gpu_asr=True,
# Maximum number of steps per episode
max_number_of_steps=64,
......@@ -113,12 +73,14 @@ class RlAudioNavAbstractEnv:
# TODO: useless ? (should belong to the room I think)
n_channels=1,
stft=eu.AttrDict(
# STFT window length
window_length=256,
),
seed=int(time.time())
)
# Define the observation space
observation_space: gym.Space = StftSpace()
# Define the action space
# self.action_space: spaces.Space = spaces.Box(low=-0.5, high=0.5,
# shape=(2,),
......@@ -131,27 +93,22 @@ class RlAudioNavAbstractEnv:
action_space: gym.Space = spaces.Discrete(NUM_ACTIONS)
def __init__(self,
config: eu.AttrDict = None,
in_thread: bool = False,
**kwargs) -> None:
config: eu.AttrDict = None) -> None:
"""
Args:
config (AttrDict): Environment configuration.
in_thread (bool): Whether this environment is being started within a thread.
"""
self.config: eu.AttrDict = eu.combine_dicts(kwargs,
config,
RlAudioNavAbstractEnv.default_config)
self.config: eu.AttrDict = eu.combine_dicts(config,
RlanAbstractEnv.default_config)
# Logger
self.logger: logging.Logger = logging.getLogger(__name__ + '.' + self.__class__.__name__)
# Seed (might be useless)
self._seed: int = self.config.seed
# STFT window length
self.stft_window_length: int = self.config.stft_window_length
# Room initialization
#######################
# Room initialization #
#######################
self.room: Room
if self.config.room.gpu:
self.logger.info("Instanciating a `GpuRirRoom`")
......@@ -163,9 +120,34 @@ class RlAudioNavAbstractEnv:
self.config.room = self.room.config
self.room.init_grid()
self.agent_mic: Microphone
self.speech_source: Source
self.asr_decoder: AsrDecoder
self.agent_positions: list[np.ndarray] = []
########
# STFT #
########
self.stft: eu.AttrDict = self.config.stft
self.num_signal_samples: int = -1
if self.config.max_obs_duration > 0:
self.num_signal_samples = self.config.max_obs_duration \
* self.room.sampling_frequency
if 'overlap' not in self.stft:
self.stft.overlap = self.stft.window_length // 2
if 'hopsize' not in self.stft:
self.stft.hopsize = self.stft.window_length - self.stft.overlap
num_freq_bins: int = (self.stft.window_length // 2) + 1
len_stft: int = ceil(self.num_signal_samples / self.stft.hopsize) + 1
self.stft.shape = (num_freq_bins, len_stft)
# TODO check that values are indeed between 0 and 1...
# Define the observation space
self.observation_space: gym.Space = gym.spaces.Box(low=0, high=1,
shape=self.stft.shape)
########
# Misc #
########
self.speech_data_set_iterator: Iterator = speech_data_set_iterator()
# The environment keeps track of the number of steps done.
......@@ -175,16 +157,6 @@ class RlAudioNavAbstractEnv:
self.logger.info("Configuration:\n%s", pformat(self.config))
self.agent_mic: Microphone
self.speech_source: Source
self.agent_positions: list[np.ndarray] = []
# UI
# self.user_interface: Optional[ui.Ui] = None
# def _init_ui(self) -> None:
# self.user_interface = ui.Ui()
def get_obs(self,
speech_signal: SpeechSignal = None) -> tuple[SpeechSignal, np.ndarray, np.ndarray]:
"""
......@@ -210,14 +182,14 @@ class RlAudioNavAbstractEnv:
listened_signal: np.ndarray = self.room.get_audio_at_mic(mic_index=0)
# Optionnaly, crop the audio
if self.config.max_obs_duration > 0:
max_num_samples: int = self.config.max_obs_duration * self.room.sampling_frequency
if self.num_signal_samples > 0:
total_num_samples: int = listened_signal.shape[-1]
if total_num_samples > max_num_samples:
if total_num_samples > self.num_signal_samples:
start_index: int = np.random.randint(low=0,
high=total_num_samples - max_num_samples)
listened_signal = listened_signal[start_index:start_index + max_num_samples]
start_index: int = np.random.randint(
low=0,
high=total_num_samples - self.num_signal_samples)
listened_signal = listened_signal[start_index:start_index + self.num_signal_samples]
# Normalize signal and cast it to float 32 before processing STFT
listened_signal = utils.to_float32(signal=listened_signal)
......@@ -227,11 +199,16 @@ class RlAudioNavAbstractEnv:
freqs: np.ndarray
stft: np.ndarray
freqs, times, stft = scipy.signal.stft(x=listened_signal,
fs=self.config.room.sampling_frequency)
fs=self.config.room.sampling_frequency,
nperseg=self.stft.window_length,
noverlap=self.stft.noverlap)
if self.num_signal_samples > 0:
assert stft.shape == self.stft.shape
# self.logger.debug("stft shape: %s", stft.shape)
# Take the modulus of the complex stft coefficients.
stft = np.absolute(stft)
self.logger.debug("Stft shape: %s", stft.shape)
# plot_stft(stft=stft,
# times=times,
......@@ -277,20 +254,6 @@ class RlAudioNavAbstractEnv:
# Reset list of agent positions
self.agent_positions = [self.agent_mic.location]
def _compute_reward(self,
                    ground_truth: str,
                    hypothesis: str) -> float:
    """
    Turn the ASR word error rate into a reward in [0, 1].

    Args:
        ground_truth (str): Reference transcript of the speech sample.
        hypothesis (str): Transcript produced by the ASR decoder.

    Returns:
        float: `1 - min(1, WER)`, i.e. a reward between 0 and 1.
    """
    self.logger.debug("Computing WER")

    # Clip the WER at 1 so that the reward stays within [0, 1].
    # A negative reward would not be a fundamental problem, but a very high
    # WER is not caused by a 'very bad' positioning, so there is no need to
    # penalize the agent harder in that case.
    wer: float = compute_wer(ground_truth=ground_truth,
                             hypothesis=hypothesis)
    reward: float = 1 - min(1, wer)

    assert 0 <= reward <= 1
    return reward
def _move_agent_and_get_obs(self,
action: int) -> tuple[SpeechSignal, np.ndarray, np.ndarray]:
......@@ -328,10 +291,6 @@ class RlAudioNavAbstractEnv:
def render(self, mode: str = 'matplotlib') -> None:
# Clear the screen
# time.sleep(1)
# print(chr(27) + "[2J")
if mode == 'ascii':
grid: list[list[str]] = [['o' for _ in range(self.room.n_x)]
......@@ -403,78 +362,3 @@ class RlAudioNavAbstractEnv:
def close(self) -> None:
    """Release rendering resources: close the matplotlib figure if one exists."""
    figure = getattr(self, 'figure', None)
    if figure is not None:
        matplotlib.pyplot.close(fig=figure)
class RlAudioNavEnv(RlAudioNavAbstractEnv, gym.Env):
    """
    An environment in which the agent has to move to "hear better" (i.e. to maximize the ASR).
    """

    def __init__(self,
                 config: eu.AttrDict = None,
                 init_asr_decoder: bool = False,
                 in_thread: bool = False) -> None:
        """
        Args:
            config (AttrDict): Environment configuration.
            init_asr_decoder (bool): Whether to instantiate the ASR decoder
                selected by `config.asr` ('speechbrain' or 'vosk').
            in_thread (bool): Whether this environment is being started within a thread.
        """
        RlAudioNavAbstractEnv.__init__(self,
                                       config=config,
                                       in_thread=in_thread)

        if init_asr_decoder:
            # Initialize the ASR decoder.
            # Imports are local so that only the selected backend gets loaded.
            if self.config.asr == 'speechbrain':
                from ...asr.speechbrain import SpeechBrainAsr
                self.asr_decoder = SpeechBrainAsr(use_gpu=self.config.enable_gpu_asr)

            elif self.config.asr == 'vosk':
                from ...asr.vosk import VoskAsr
                self.asr_decoder = VoskAsr(use_gpu=self.config.enable_gpu_asr,
                                           in_thread=in_thread)

    def reset(self, seed: Optional[int]) -> np.ndarray:
        """
        Reset the environment and return the initial observation (STFT).

        Args:
            seed (Optional[int]): Currently unused.

        Returns:
            np.ndarray: The initial observation.
        """
        # TODO manage the `seed` argument (right now I don't really know what to do with it)
        self._reset()

        # `get_obs` returns (speech_signal, listened_signal, stft);
        # only the STFT is the agent's observation.
        initial_observation: np.ndarray = self.get_obs()[-1]
        return initial_observation

    def step(self, action: int) -> tuple[np.ndarray, float, bool, dict]:
        """
        Perform a step in the environment:
        - Move the microphone according to the given action
        - Select a speech sample
        - Simulate listened signal and compute its STFT (observation)
        - Decode the signal using ASR and compute the reward

        Args:
            action (np.ndarray): The agent action (a 2D displacement vector).
                shape = (2,)

        Returns:
            observation (np.ndarray): The signal listened by the agent.
            reward (float): The reward obtained for this step.
                reward = 1 - WER
            done (bool): If the episode is over.
            info (dict): A dictionary containing miscellaneous information about the
                environment.
        """
        self.step_counter += 1

        info: dict = {
            'step_id': self.step_counter
        }

        # The episode ends after a fixed maximum number of steps.
        done: bool = False
        if self.step_counter >= self.max_number_of_steps:
            done = True
            info['done_reason'] = f"max_number_of_steps ({self.max_number_of_steps}) reached"

        speech_signal, listened_signal, observation = self._move_agent_and_get_obs(action=action)

        # Run the ASR model
        self.logger.debug("Decoding the signal")
        asr_transcript: str = self.asr_decoder.decode(signal=listened_signal)

        # Compute WER
        reward: float = self._compute_reward(ground_truth=speech_signal.transcript,
                                             hypothesis=asr_transcript)

        return observation, reward, done, info
from typing import Optional
import gym
import exputils as eu
import numpy as np
from numpy.random import default_rng # type: ignore
from ...asr import AsrDecoder, compute_wer
from .abstract_env import RlanAbstractEnv
class StftSpace(gym.Space):
    """
    Stft space of variable length (i.e. variable sound sample duration).
    """

    def contains(self, x) -> bool:
        """Return True if `x` is a one-dimensional numpy array."""
        return isinstance(x, np.ndarray) and x.ndim == 1

    def sample(self) -> np.ndarray:
        """
        Warning: a random spectrogram makes no sense.
        It can maybe be useful for debugging purposes.

        Returns:
            np.ndarray: A random int16 signal (16 kHz sampling assumed).
        """
        generator: np.random.Generator = default_rng(self.seed()[0])

        # Sample a signal duration (in seconds).
        # 12.7s is the average duration of a LibriSpeech sample.
        # 3.57s is the standard deviation of the durations.
        length: float = generator.normal(loc=12.7, scale=3.57, size=1)[0]

        # `size` must be an integer number of samples: `length * 16000` is a
        # float (and may even be negative, the normal distribution has
        # unbounded support), which would make `integers` raise a TypeError.
        num_samples: int = max(0, int(length * 16000))

        signal: np.ndarray = generator.integers(low=np.iinfo(np.int16).min,  # type: ignore
                                                high=np.iinfo(np.int16).max,
                                                dtype=np.int16,
                                                size=num_samples)
        # TODO return the stft, not the signal
        # Actually, directly sample the STFT...
        return signal

    def __eq__(self, o) -> bool:
        return isinstance(o, StftSpace)
class RlanOnlineDecodingAbstractEnv(RlanAbstractEnv):
    """
    Environment where each listened audio sample is provided to the ASR decoder.
    The obtained transcript is used to compute the WER which leads to the reward.
    """

    # Default configuration, merged with the user-provided one in `__init__`.
    default_config: eu.AttrDict = eu.AttrDict(
        # Use GPU asr
        enable_gpu_asr=True,
    )

    def __init__(self, config: eu.AttrDict = None) -> None:
        """
        Args:
            config (AttrDict): Environment configuration, merged with
                `default_config`.
        """
        # NOTE(review): `RlanAbstractEnv.__init__` is not called here, so
        # attributes it would set (e.g. `self.logger`, which `_compute_reward`
        # reads) must come from elsewhere — confirm this is intentional.
        self.config: eu.AttrDict = eu.combine_dicts(config,
                                                    RlanOnlineDecodingAbstractEnv.default_config)

        # TODO
        self.observation_space: gym.Space = StftSpace()

    def _compute_reward(self,
                        ground_truth: str,
                        hypothesis: str) -> float:
        """
        Compute the reward `1 - min(1, WER)` for a decoded transcript.

        Args:
            ground_truth (str): Reference transcript.
            hypothesis (str): Transcript produced by the ASR decoder.

        Returns:
            float: Reward in [0, 1].
        """
        self.logger.debug("Computing WER")

        # Ensure the WER is less than 1 so that reward stays between 0 and 1.
        # Actually, it is not a fundamental problem if the reward is negative but a very high WER
        # will not be caused by a 'very bad' positionning. Hence, no need to hardly penalyze the
        # agent in this case.
        reward: float = 1 - min(1, compute_wer(ground_truth=ground_truth,
                                               hypothesis=hypothesis))
        assert 0 <= reward <= 1

        return reward
class RlanOnlineDecodingEnv(RlanOnlineDecodingAbstractEnv, gym.Env):
    """
    Single thread online decoding environment.
    """

    # Default configuration, merged with the user-provided one in `__init__`.
    default_config: eu.AttrDict = eu.AttrDict(
        # ASR backend
        # asr='vosk',
        asr='speechbrain',
    )

    def __init__(self,
                 config: eu.AttrDict = None,
                 init_asr_decoder: bool = False,
                 in_thread: bool = False) -> None:
        """
        Args:
            config (AttrDict): Environment configuration.
            init_asr_decoder (bool): Whether to instantiate the ASR decoder
                selected by `config.asr` ('speechbrain' or 'vosk').
            in_thread (bool): Whether this environment is being started within a thread.
        """
        config = eu.combine_dicts(config, RlanOnlineDecodingEnv.default_config)

        RlanOnlineDecodingAbstractEnv.__init__(self,
                                               config=config)

        # ASR decoder; only assigned when `init_asr_decoder` is True.
        self.asr_decoder: AsrDecoder
        if init_asr_decoder:
            # Initialize the ASR decoder.
            # Imports are local so that only the selected backend gets loaded.
            if self.config.asr == 'speechbrain':
                from ...asr.speechbrain import SpeechBrainAsr
                self.asr_decoder = SpeechBrainAsr(use_gpu=self.config.enable_gpu_asr)

            elif self.config.asr == 'vosk':
                from ...asr.vosk import VoskAsr
                self.asr_decoder = VoskAsr(use_gpu=self.config.enable_gpu_asr,
                                           in_thread=in_thread)

    def reset(self, seed: Optional[int]) -> np.ndarray:
        """
        Reset the environment and return the initial observation (STFT).

        Args:
            seed (Optional[int]): Currently unused.

        Returns:
            np.ndarray: The initial observation.
        """
        # TODO manage the `seed` argument (right now I don't really know what to do with it)
        self._reset()

        # `get_obs` returns (speech_signal, listened_signal, stft);
        # only the STFT is the agent's observation.
        initial_observation: np.ndarray = self.get_obs()[-1]
        return initial_observation

    def step(self, action: int) -> tuple[np.ndarray, float, bool, dict]:
        """
        Perform a step in the environment:
        - Move the microphone according to the given action
        - Select a speech sample
        - Simulate listened signal and compute its STFT (observation)
        - Decode the signal using ASR and compute the reward

        Args:
            action (np.ndarray): The agent action (a 2D displacement vector).
                shape = (2,)

        Returns:
            observation (np.ndarray): The signal listened by the agent.
            reward (float): The reward obtained for this step.
                reward = 1 - WER
            done (bool): If the episode is over.
            info (dict): A dictionary containing miscellaneous information about the
                environment.
        """
        self.step_counter += 1

        info: dict = {
            'step_id': self.step_counter
        }

        # The episode ends after a fixed maximum number of steps.
        done: bool = False
        if self.step_counter >= self.max_number_of_steps:
            done = True
            info['done_reason'] = f"max_number_of_steps ({self.max_number_of_steps}) reached"

        speech_signal, listened_signal, observation = self._move_agent_and_get_obs(action=action)

        # Run the ASR model
        self.logger.debug("Decoding the signal")
        asr_transcript: str = self.asr_decoder.decode(signal=listened_signal)

        # Compute WER
        reward: float = self._compute_reward(ground_truth=speech_signal.transcript,
                                             hypothesis=asr_transcript)

        return observation, reward, done, info
......@@ -6,27 +6,28 @@ import gym
from ...audio_simulator import SpeechSignal
from ...asr.speechbrain import SpeechBrainAsr
from .rl_audio_nav_env import RlAudioNavAbstractEnv
from .online_decoding import RlanOnlineDecodingAbstractEnv
class RlAudioNavVecEnv(gym.vector.VectorEnv, RlAudioNavAbstractEnv):
class RlanOnlineDecodingSpeechbrainVecEnv(gym.vector.VectorEnv, RlanOnlineDecodingAbstractEnv):
def __init__(self,
num_envs: int,
env_config: eu.AttrDict = RlAudioNavAbstractEnv.default_config,
config: eu.AttrDict = None,
copy: bool = True) -> None:
gym.vector.VectorEnv.__init__(self,
num_envs=num_envs,
observation_space=RlAudioNavAbstractEnv.observation_space,
action_space=RlAudioNavAbstractEnv.action_space)
RlanOnlineDecodingAbstractEnv.__init__(self,
config=config)
RlAudioNavAbstractEnv.__init__(self,
config=env_config,
in_thread=False)
# TODO the spaces definition is not good.
gym.vector.VectorEnv.__init__(
self,
num_envs=num_envs,
observation_space=RlanOnlineDecodingAbstractEnv.observation_space,
action_space=RlanOnlineDecodingAbstractEnv.action_space)
# Initialize the ASR decoder
self.asr_decoder = SpeechBrainAsr(use_gpu=self.config.enable_gpu_asr)
self.asr_decoder: SpeechBrainAsr = SpeechBrainAsr(use_gpu=self.config.enable_gpu_asr)
self.step_counter: int = -1
......@@ -117,7 +118,6 @@ class RlAudioNavVecEnv(gym.vector.VectorEnv, RlAudioNavAbstractEnv):
rewards = [self._compute_reward(ground_truth=speech_signal.transcript,
hypothesis=asr_transcript)
for speech_signal, asr_transcript in zip(speech_signals, transcriptions)]
# rewards = [1 for _ in range(self.num_envs)] # TODO remove
return (deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
......@@ -129,7 +129,7 @@ class RlAudioNavVecEnv(gym.vector.VectorEnv, RlAudioNavAbstractEnv):
# pylint: disable=arguments-differ
def close_extras(self, **kwargs) -> None:
# TODO check what we usually do in a single RlAudioNavEnv
# TODO check what we usually do in a single RlanOnlineDecodingEnv
pass
def _check_observation_spaces(self) -> bool:
......
......@@ -17,7 +17,7 @@ from gym.error import (
)
from vosk import SetLogLevel, GpuInit
from .rl_audio_nav_env import RlAudioNavEnv
from .online_decoding import RlanOnlineDecodingAbstractEnv, RlanOnlineDecodingEnv
class AsyncState(Enum):
......@@ -33,8 +33,8 @@ class Worker:
error_queue: Queue) -> None:
# print(f'Hello, I am thread {index}')
self._env: gym.Env = RlAudioNavEnv(config=env_config,
in_thread=True)
self._env: gym.Env = RlanOnlineDecodingEnv(config=env_config,
in_thread=True)
# print('env done init')
self.index: int = index
......@@ -133,24 +133,27 @@ class Worker:
self._input_queue.put(('close', None))
class RlAudioNavVecEnv(gym.vector.VectorEnv):
class RlanOnlineDecodingVoskVecEnv(RlanOnlineDecodingAbstractEnv, gym.vector.VectorEnv):
def __init__(self,
num_envs: int,
env_config: eu.AttrDict = eu.AttrDict(),
config: eu.AttrDict = None,
copy: bool = True) -> None:
self.logger: logging.Logger = logging.getLogger(self.__class__.__name__)
super().__init__(num_envs=num_envs,
observation_space=RlAudioNavEnv.observation_space,
action_space=RlAudioNavEnv.action_space)
RlanOnlineDecodingAbstractEnv.__init__(self,
config=config)
gym.vector.VectorEnv.__init__(num_envs=num_envs,
observation_space=self.observation_space,
action_space=RlanOnlineDecodingAbstractEnv.action_space)
self.observations: list[np.ndarray] = []
self.copy: bool = copy
if env_config.enable_gpu_asr:
if self.config.enable_gpu_asr:
# Initialize GPU for ASR
self.logger.info('Initializing GPU for ASR `GpuInit()`')
SetLogLevel(0)
......@@ -159,7 +162,7 @@ class RlAudioNavVecEnv(gym.vector.VectorEnv):
self.error_queue: Queue = Queue()
self.workers: list[Worker] = [Worker(index=worker_index,
env_config=env_config,
env_config=config,
error_queue=self.error_queue)
for worker_index in range(self.num_envs)]
......
......@@ -12,11 +12,11 @@ import numpy as np
import gym
import exputils as eu
from .rl_audio_nav_env import RlAudioNavEnv, RlAudioNavAbstractEnv
from .abstract_env import RlanAbstractEnv
from ...audio_simulator.wer_map.wer_map import WerMap, is_map_config_compatible
class RlanWerMapEnv(RlAudioNavEnv):
class RlanWerMapEnv(RlanAbstractEnv, gym.Env):
"""
An environment in which the agent has to move to "hear better" (i.e. to maximize the ASR).
"""
......@@ -42,7 +42,8 @@ class RlanWerMapEnv(RlAudioNavEnv):
from pprint import pformat
config = eu.combine_dicts(kwargs, config, RlanWerMapEnv.default_config)
super().__init__(config=config)
RlanAbstractEnv.__init__(self,
config=config)
self.config.wer_map.room = self.config.room
self.logger.info("definitive config for this RlanWerMapEnv:\n%s", pformat(self.config))
......@@ -164,15 +165,16 @@ class RlanWerMapSyncVecEnv(gym.vector.VectorEnv, RlanWerMapEnv):
def __init__(self,
num_envs: int,
env_config: eu.AttrDict,
config: eu.AttrDict,
copy: bool = True) -> None:
RlanWerMapEnv.__init__(self,
config=config)
gym.vector.VectorEnv.__init__(self,
num_envs=num_envs,
observation_space=RlAudioNavAbstractEnv.observation_space,
action_space=RlAudioNavAbstractEnv.observation_space)
RlanWerMapEnv.__init__(self,
config=env_config)
observation_space=self.observation_space,
action_space=RlanAbstractEnv.action_space)
self.step_counter: int = -1
......@@ -267,15 +269,13 @@ class RlanWerMapAsyncVecEnv(gym.vector.AsyncVectorEnv):
def __init__(self,
num_envs: int,
env_config: eu.AttrDict) -> None:
config: eu.AttrDict) -> None:
self.config: eu.AttrDict = eu.combine_dicts(env_config,
RlanWerMapEnv.default_config,
RlAudioNavEnv.default_config,
RlAudioNavAbstractEnv.default_config)
self.config: eu.AttrDict = eu.combine_dicts(config,
RlanWerMapEnv.default_config)
super().__init__(
env_fns=[lambda: RlanWerMapEnv(config=env_config)
env_fns=[lambda: RlanWerMapEnv(config=config)
for _ in range(num_envs)],
shared_memory=False,
context='spawn',
......
......@@ -24,7 +24,7 @@ ENV_NAME: str
# ENV_NAME = 'debug-v0'
# ENV_NAME = 'CartPole-v1'
# ENV_NAME = 'rlan-wermap-2Dpos-v0'
ENV_NAME = 'rlan-wermap-async-vec-env-v0'
ENV_NAME = 'rlan-wermap-vec-async-v0'
N_ENVS: int
N_ENVS = 16
......@@ -122,11 +122,11 @@ def _init_environment(env_name: str,
LOGGER.debug("env_name: %s", env_name)
LOGGER.debug("env_config: %s", pformat(config))
kwargs: dict[str, Any] = {}
# if config is not None:
# kwargs['config'] = config
if config is not None:
kwargs['config'] = config
kwargs['num_envs'] = N_ENVS
kwargs['env_config'] = config
# kwargs['env_config'] = config
# environment: RlAudioNavEnv = RlAudioNavEnv()
# environment: gym.Env = RlanWerMapEnv(config=config)
......
......@@ -48,6 +48,7 @@ RUN_EXPUTILS_SCRIPT=${PROJECT_DIR}/rl_audio_nav/bin/run_exputils.py
# SCRIPT="asr/speechbrain/compute_wer_on_data_set"
# SCRIPT="asr/vosk/compute_wer_on_data_set"
# SCRIPT="audio_simulator/room/test_gpurir_multiprocessing"
# SCRIPT="audio_simulator/stft_dimensions"
SCRIPT="rl/ppo/run"
SCRIPT="$SCRIPT.py"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment