Skip to content

Commit 1246db1

Browse files
committed
[BugFix] Account for composite actions in gym
ghstack-source-id: c09b59904a89d45fa24a61a5e8a24fe307320794
Pull Request resolved: #2718
1 parent 319bb68 commit 1246db1

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

test/test_libs.py

+73-1
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
import importlib.util
88
import urllib.error
99

10+
from gym.core import ObsType
11+
1012
_has_isaac = importlib.util.find_spec("isaacgym") is not None
1113

1214
if _has_isaac:
@@ -23,7 +25,7 @@
2325
from contextlib import nullcontext
2426
from pathlib import Path
2527
from sys import platform
26-
from typing import Optional, Union
28+
from typing import Optional, Tuple, Union
2729
from unittest import mock
2830

2931
import numpy as np
@@ -634,6 +636,76 @@ def test_torchrl_to_gym(self, backend, numpy):
634636
finally:
635637
set_gym_backend(gb).set()
636638

639+
@implement_for("gym", None, "0.26")
640+
def test_gym_dict_action_space(self):
641+
pytest.skip("tested for gym > 0.26 - no backward issue")
642+
643+
@implement_for("gym", "0.26", None)
644+
def test_gym_dict_action_space(self): # noqa: F811
645+
import gym
646+
from gym import Env
647+
648+
class CompositeActionEnv(Env):
649+
def __init__(self):
650+
self.action_space = gym.spaces.Dict(
651+
a0=gym.spaces.Discrete(2), a1=gym.spaces.Box(-1, 1)
652+
)
653+
self.observation_space = gym.spaces.Box(-1, 1)
654+
655+
def step(self, action):
656+
return (0.5, 0.0, False, False, {})
657+
658+
def reset(
659+
self,
660+
*,
661+
seed: Optional[int] = None,
662+
options: Optional[dict] = None,
663+
) -> Tuple[ObsType, dict]:
664+
return (0.0, {})
665+
666+
env = CompositeActionEnv()
667+
torchrl_env = GymWrapper(env)
668+
assert isinstance(torchrl_env.action_spec, Composite)
669+
assert len(torchrl_env.action_keys) == 2
670+
r = torchrl_env.rollout(10)
671+
assert isinstance(r[0]["a0"], torch.Tensor)
672+
assert isinstance(r[0]["a1"], torch.Tensor)
673+
assert r[0]["observation"] == 0
674+
assert r[1]["observation"] == 0.5
675+
676+
@implement_for("gymnasium")
677+
def test_gym_dict_action_space(self): # noqa: F811
678+
import gymnasium as gym
679+
from gymnasium import Env
680+
681+
class CompositeActionEnv(Env):
682+
def __init__(self):
683+
self.action_space = gym.spaces.Dict(
684+
a0=gym.spaces.Discrete(2), a1=gym.spaces.Box(-1, 1)
685+
)
686+
self.observation_space = gym.spaces.Box(-1, 1)
687+
688+
def step(self, action):
689+
return (0.5, 0.0, False, False, {})
690+
691+
def reset(
692+
self,
693+
*,
694+
seed: Optional[int] = None,
695+
options: Optional[dict] = None,
696+
) -> Tuple[ObsType, dict]:
697+
return (0.0, {})
698+
699+
env = CompositeActionEnv()
700+
torchrl_env = GymWrapper(env)
701+
assert isinstance(torchrl_env.action_spec, Composite)
702+
assert len(torchrl_env.action_keys) == 2
703+
r = torchrl_env.rollout(10)
704+
assert isinstance(r[0]["a0"], torch.Tensor)
705+
assert isinstance(r[0]["a1"], torch.Tensor)
706+
assert r[0]["observation"] == 0
707+
assert r[1]["observation"] == 0.5
708+
637709
@pytest.mark.parametrize(
638710
"env_name",
639711
[

torchrl/envs/gym_like.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -292,7 +292,10 @@ def read_obs(
292292
return observations
293293

294294
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
295-
action = tensordict.get(self.action_key)
295+
if len(self.action_keys) == 1:
296+
action = tensordict.get(self.action_key)
297+
else:
298+
action = tensordict.select(*self.action_keys).to_dict()
296299
if self._convert_actions_to_numpy:
297300
action = self.read_action(action)
298301

0 commit comments

Comments
 (0)