|
7 | 7 | import importlib.util
|
8 | 8 | import urllib.error
|
9 | 9 |
|
| 10 | +from gym.core import ObsType |
| 11 | + |
10 | 12 | _has_isaac = importlib.util.find_spec("isaacgym") is not None
|
11 | 13 |
|
12 | 14 | if _has_isaac:
|
|
23 | 25 | from contextlib import nullcontext
|
24 | 26 | from pathlib import Path
|
25 | 27 | from sys import platform
|
26 |
| -from typing import Optional, Union |
| 28 | +from typing import Optional, Tuple, Union |
27 | 29 | from unittest import mock
|
28 | 30 |
|
29 | 31 | import numpy as np
|
@@ -634,6 +636,76 @@ def test_torchrl_to_gym(self, backend, numpy):
|
634 | 636 | finally:
|
635 | 637 | set_gym_backend(gb).set()
|
636 | 638 |
|
@implement_for("gym", None, "0.26")
def test_gym_dict_action_space(self):
    """Skip the dict-action-space check on gym releases older than 0.26.

    The actual check is implemented by the sibling variants of this test
    that are registered for gym starting at 0.26 and for gymnasium.
    """
    # The sibling implementation is registered with "0.26" as its lower bound,
    # which `implement_for` treats as inclusive, hence ">=" rather than ">".
    pytest.skip("tested for gym >= 0.26 - no backward issue")
| 642 | + |
@implement_for("gym", "0.26", None)
def test_gym_dict_action_space(self):  # noqa: F811
    """Check that a gym env with a ``Dict`` action space can be wrapped.

    A throwaway env exposing a two-entry ``gym.spaces.Dict`` action space is
    wrapped in ``GymWrapper``. The test asserts that the wrapper's action spec
    is a ``Composite`` with two action keys, then calls ``rollout(10)`` and
    checks the action entries and the first two observations.
    """
    import gym
    from gym import Env

    class CompositeActionEnv(Env):
        """Stub env: dict action space, constant step and reset outputs."""

        def __init__(self):
            discrete_part = gym.spaces.Discrete(2)
            continuous_part = gym.spaces.Box(-1, 1)
            self.action_space = gym.spaces.Dict(a0=discrete_part, a1=continuous_part)
            self.observation_space = gym.spaces.Box(-1, 1)

        def step(self, action):
            # Five-element step tuple; the action is ignored.
            observation, reward = 0.5, 0.0
            terminated = truncated = False
            return observation, reward, terminated, truncated, {}

        def reset(
            self,
            *,
            seed: Optional[int] = None,
            options: Optional[dict] = None,
        ) -> Tuple[ObsType, dict]:
            initial_observation = 0.0
            return initial_observation, {}

    wrapped = GymWrapper(CompositeActionEnv())
    assert isinstance(wrapped.action_spec, Composite)
    assert len(wrapped.action_keys) == 2

    rollout = wrapped.rollout(10)
    first_step = rollout[0]
    # Both entries of the dict action must show up as tensors in the rollout.
    for action_key in ("a0", "a1"):
        assert isinstance(first_step[action_key], torch.Tensor)
    # reset() yields 0.0, every step() yields 0.5.
    assert first_step["observation"] == 0
    second_step = rollout[1]
    assert second_step["observation"] == 0.5
| 675 | + |
@implement_for("gymnasium")
def test_gym_dict_action_space(self):  # noqa: F811
    """Check that a gymnasium env with a ``Dict`` action space can be wrapped.

    A throwaway env exposing a two-entry ``gymnasium.spaces.Dict`` action
    space is wrapped in ``GymWrapper``. The test asserts that the wrapper's
    action spec is a ``Composite`` with two action keys, then calls
    ``rollout(10)`` and checks the action entries and the first two
    observations.
    """
    import gymnasium as gym
    from gymnasium import Env

    # Use gymnasium's own ObsType for the annotation below, so this variant
    # does not lean on the legacy gym package for its typing.
    from gymnasium.core import ObsType

    class CompositeActionEnv(Env):
        """Stub env: dict action space, constant step and reset outputs."""

        def __init__(self):
            self.action_space = gym.spaces.Dict(
                a0=gym.spaces.Discrete(2), a1=gym.spaces.Box(-1, 1)
            )
            self.observation_space = gym.spaces.Box(-1, 1)

        def step(self, action):
            # (observation, reward, terminated, truncated, info); action is ignored.
            return (0.5, 0.0, False, False, {})

        def reset(
            self,
            *,
            seed: Optional[int] = None,
            options: Optional[dict] = None,
        ) -> Tuple[ObsType, dict]:
            # (observation, info)
            return (0.0, {})

    env = CompositeActionEnv()
    torchrl_env = GymWrapper(env)
    assert isinstance(torchrl_env.action_spec, Composite)
    assert len(torchrl_env.action_keys) == 2
    r = torchrl_env.rollout(10)
    # Both entries of the dict action must show up as tensors in the rollout.
    assert isinstance(r[0]["a0"], torch.Tensor)
    assert isinstance(r[0]["a1"], torch.Tensor)
    # reset() yields 0.0, every step() yields 0.5.
    assert r[0]["observation"] == 0
    assert r[1]["observation"] == 0.5
| 708 | + |
637 | 709 | @pytest.mark.parametrize(
|
638 | 710 | "env_name",
|
639 | 711 | [
|
|
0 commit comments