From fdfb8a54c638d5bec95d1ae3555ec7008ad93d1b Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Thu, 16 Jan 2025 17:05:35 -0500
Subject: [PATCH] Allow the use of the deterministic/stochastic actions during evaluation (#252)

---
 CHANGELOG.md                      |  3 +++
 skrl/trainers/jax/base.py         | 11 +++++++++--
 skrl/trainers/jax/sequential.py   | 11 ++++++++---
 skrl/trainers/jax/step.py         | 11 ++++++++---
 skrl/trainers/torch/base.py       | 11 +++++++++--
 skrl/trainers/torch/parallel.py   | 16 +++++++++++++---
 skrl/trainers/torch/sequential.py | 11 ++++++++---
 skrl/trainers/torch/step.py       | 11 ++++++++---
 8 files changed, 66 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e70b3d47..8f09793b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - Support for automatic mixed precision training in PyTorch
 - `init_state_dict` method to initialize model's lazy modules in PyTorch
 - Model instantiators `fixed_log_std` parameter to define immutable log standard deviations
+- Define the `stochastic_evaluation` trainer config to allow the use of the actions returned by the agent's model
+  as-is instead of deterministic actions (mean-actions in Gaussian-based models) during evaluation.
+  Make the return of deterministic actions the default behavior.
 
 ### Changed
 - Call agent's `pre_interaction` method during evaluation
diff --git a/skrl/trainers/jax/base.py b/skrl/trainers/jax/base.py
index 2e086820..7320dfc2 100644
--- a/skrl/trainers/jax/base.py
+++ b/skrl/trainers/jax/base.py
@@ -63,6 +63,7 @@ def __init__(
         self.disable_progressbar = self.cfg.get("disable_progressbar", False)
         self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
         self.environment_info = self.cfg.get("environment_info", "episode")
+        self.stochastic_evaluation = self.cfg.get("stochastic_evaluation", False)
 
         self.initial_timestep = 0
 
@@ -248,7 +249,8 @@ def single_agent_eval(self) -> None:
 
             with contextlib.nullcontext():
                 # compute actions
-                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
+                outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
+                actions = outputs[0] if self.stochastic_evaluation else outputs[-1].get("mean_actions", outputs[0])
 
                 # step the environments
                 next_states, rewards, terminated, truncated, infos = self.env.step(actions)
@@ -375,7 +377,12 @@ def multi_agent_eval(self) -> None:
 
             with contextlib.nullcontext():
                 # compute actions
-                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
+                outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
+                actions = (
+                    outputs[0]
+                    if self.stochastic_evaluation
+                    else {k: outputs[-1][k].get("mean_actions", outputs[0][k]) for k in outputs[-1]}
+                )
 
                 # step the environments
                 next_states, rewards, terminated, truncated, infos = self.env.step(actions)
diff --git a/skrl/trainers/jax/sequential.py b/skrl/trainers/jax/sequential.py
index 5e690d91..12fda411 100644
--- a/skrl/trainers/jax/sequential.py
+++ b/skrl/trainers/jax/sequential.py
@@ -19,7 +19,8 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
-    "environment_info": "episode",   # key used to get and log environment info
+    "environment_info": "episode",      # key used to get and log environment info
+    "stochastic_evaluation": False,     # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-jax]
 # fmt: on
@@ -181,10 +182,14 @@ def eval(self) -> None:
 
             with contextlib.nullcontext():
                 # compute actions
+                outputs = [
+                    agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)
+                    for agent, scope in zip(self.agents, self.agents_scope)
+                ]
                 actions = jnp.vstack(
                     [
-                        agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0]
-                        for agent, scope in zip(self.agents, self.agents_scope)
+                        output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
+                        for output in outputs
                     ]
                 )
 
diff --git a/skrl/trainers/jax/step.py b/skrl/trainers/jax/step.py
index 0bb3b361..689948b3 100644
--- a/skrl/trainers/jax/step.py
+++ b/skrl/trainers/jax/step.py
@@ -21,7 +21,8 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
-    "environment_info": "episode",   # key used to get and log environment info
+    "environment_info": "episode",      # key used to get and log environment info
+    "stochastic_evaluation": False,     # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-jax]
 # fmt: on
@@ -216,10 +217,14 @@ def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None)
 
         with contextlib.nullcontext():
             # compute actions
+            outputs = [
+                agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)
+                for agent, scope in zip(self.agents, self.agents_scope)
+            ]
             actions = jnp.vstack(
                 [
-                    agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0]
-                    for agent, scope in zip(self.agents, self.agents_scope)
+                    output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
+                    for output in outputs
                 ]
             )
 
diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py
index 16b61161..6c7ef417 100644
--- a/skrl/trainers/torch/base.py
+++ b/skrl/trainers/torch/base.py
@@ -64,6 +64,7 @@ def __init__(
         self.disable_progressbar = self.cfg.get("disable_progressbar", False)
         self.close_environment_at_exit = self.cfg.get("close_environment_at_exit", True)
         self.environment_info = self.cfg.get("environment_info", "episode")
+        self.stochastic_evaluation = self.cfg.get("stochastic_evaluation", False)
 
         self.initial_timestep = 0
 
@@ -255,7 +256,8 @@ def single_agent_eval(self) -> None:
 
             with torch.no_grad():
                 # compute actions
-                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
+                outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
+                actions = outputs[0] if self.stochastic_evaluation else outputs[-1].get("mean_actions", outputs[0])
 
                 # step the environments
                 next_states, rewards, terminated, truncated, infos = self.env.step(actions)
@@ -394,7 +396,12 @@ def multi_agent_eval(self) -> None:
 
            with torch.no_grad():
                 # compute actions
-                actions = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)[0]
+                outputs = self.agents.act(states, timestep=timestep, timesteps=self.timesteps)
+                actions = (
+                    outputs[0]
+                    if self.stochastic_evaluation
+                    else {k: outputs[-1][k].get("mean_actions", outputs[0][k]) for k in outputs[-1]}
+                )
 
                 # step the environments
                 next_states, rewards, terminated, truncated, infos = self.env.step(actions)
diff --git a/skrl/trainers/torch/parallel.py b/skrl/trainers/torch/parallel.py
index b409ae7f..c0a6f6a5 100644
--- a/skrl/trainers/torch/parallel.py
+++ b/skrl/trainers/torch/parallel.py
@@ -19,7 +19,8 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
-    "environment_info": "episode",   # key used to get and log environment info
+    "environment_info": "episode",      # key used to get and log environment info
+    "stochastic_evaluation": False,     # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-torch]
 # fmt: on
@@ -65,7 +66,9 @@ def fn_processor(process_index, *args):
         elif task == "act":
             _states = queue.get()[scope[0] : scope[1]]
             with torch.no_grad():
-                _actions = agent.act(_states, timestep=msg["timestep"], timesteps=msg["timesteps"])[0]
+                stochastic_evaluation = msg["stochastic_evaluation"]
+                _outputs = agent.act(_states, timestep=msg["timestep"], timesteps=msg["timesteps"])
+                _actions = _outputs[0] if stochastic_evaluation else _outputs[-1].get("mean_actions", _outputs[0])
                 if not _actions.is_cuda:
                     _actions.share_memory_()
                 queue.put(_actions)
@@ -363,7 +366,14 @@ def eval(self) -> None:
             # compute actions
             with torch.no_grad():
                 for pipe, queue in zip(producer_pipes, queues):
-                    pipe.send({"task": "act", "timestep": timestep, "timesteps": self.timesteps})
+                    pipe.send(
+                        {
+                            "task": "act",
+                            "timestep": timestep,
+                            "timesteps": self.timesteps,
+                            "stochastic_evaluation": self.stochastic_evaluation,
+                        }
+                    )
                     queue.put(states)
 
                 barrier.wait()
diff --git a/skrl/trainers/torch/sequential.py b/skrl/trainers/torch/sequential.py
index 8304b1fd..913bf3cc 100644
--- a/skrl/trainers/torch/sequential.py
+++ b/skrl/trainers/torch/sequential.py
@@ -18,7 +18,8 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
-    "environment_info": "episode",   # key used to get and log environment info
+    "environment_info": "episode",      # key used to get and log environment info
+    "stochastic_evaluation": False,     # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-torch]
 # fmt: on
@@ -187,10 +188,14 @@ def eval(self) -> None:
 
             with torch.no_grad():
                 # compute actions
+                outputs = [
+                    agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)
+                    for agent, scope in zip(self.agents, self.agents_scope)
+                ]
                 actions = torch.vstack(
                     [
-                        agent.act(states[scope[0] : scope[1]], timestep=timestep, timesteps=self.timesteps)[0]
-                        for agent, scope in zip(self.agents, self.agents_scope)
+                        output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
+                        for output in outputs
                     ]
                 )
 
diff --git a/skrl/trainers/torch/step.py b/skrl/trainers/torch/step.py
index 77987598..2744be10 100644
--- a/skrl/trainers/torch/step.py
+++ b/skrl/trainers/torch/step.py
@@ -18,7 +18,8 @@
     "headless": False,              # whether to use headless mode (no rendering)
     "disable_progressbar": False,   # whether to disable the progressbar. If None, disable on non-TTY
     "close_environment_at_exit": True,   # whether to close the environment on normal program termination
-    "environment_info": "episode",   # key used to get and log environment info
+    "environment_info": "episode",      # key used to get and log environment info
+    "stochastic_evaluation": False,     # whether to use actions rather than (deterministic) mean actions during evaluation
 }
 # [end-config-dict-torch]
 # fmt: on
@@ -212,10 +213,14 @@ def eval(
 
         with torch.no_grad():
            # compute actions
+            outputs = [
+                agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)
+                for agent, scope in zip(self.agents, self.agents_scope)
+            ]
             actions = torch.vstack(
                 [
-                    agent.act(self.states[scope[0] : scope[1]], timestep=timestep, timesteps=timesteps)[0]
-                    for agent, scope in zip(self.agents, self.agents_scope)
+                    output[0] if self.stochastic_evaluation else output[-1].get("mean_actions", output[0])
+                    for output in outputs
                 ]
             )
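
A minimal usage sketch of the new `stochastic_evaluation` option, assuming `env` is an already-wrapped skrl environment and `agent` is a trained skrl agent (both are placeholders here, and the timestep count is illustrative):

    from skrl.trainers.torch import SequentialTrainer

    # evaluate with the sampled (stochastic) actions returned by the agent's model;
    # omit the key or set it to False to keep the new default: deterministic actions
    # ("mean_actions" when the model provides them, otherwise the returned actions)
    cfg_trainer = {"timesteps": 10000, "headless": True, "stochastic_evaluation": True}

    trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)
    trainer.eval()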