Commit 2f8bc28

Update
[ghstack-poisoned]

2 parents 5816aeb + 6a38fae commit 2f8bc28

34 files changed: +1551 −695 lines

docs/source/reference/envs.rst (+71)
@@ -865,6 +865,8 @@ The inverse process is executed with the output tensordict, where the `in_keys`
 
 Rename transform logic
 
+.. note:: During a call to `inv`, the transforms are executed in reverse order (compared to the forward / step mode).
+
 Transforming Tensors and Specs
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

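The note in the hunk above can be illustrated without torchrl at all. Below is a plain-Python sketch (illustrative names, not the torchrl API) of why `inv` must run the transforms in reverse order: if the forward pass applies t1 then t2, undoing the pipeline must apply t2's inverse first, then t1's inverse.

```python
# Illustrative sketch, independent of torchrl: forward applies the transforms
# in declaration order, so the inverse pass must run them in reverse order.
transforms = [
    (lambda x: x * 2.0, lambda x: x / 2.0),  # (forward, inverse) pair t1
    (lambda x: x + 1.0, lambda x: x - 1.0),  # (forward, inverse) pair t2
]

def forward(x):
    # forward / step mode: declaration order
    for fwd, _ in transforms:
        x = fwd(x)
    return x

def inv(x):
    # inverse mode: reversed order, so each transform undoes its own effect
    for _, invf in reversed(transforms):
        x = invf(x)
    return x

print(forward(3.0))       # (3 * 2) + 1 = 7.0
print(inv(forward(3.0)))  # back to 3.0
```

Running the inverses in declaration order instead would compute (7 / 2) - 1, which does not recover the input.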
@@ -900,6 +902,74 @@ tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand
 environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
 transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.
 
+Designing your own Transform
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To create a basic, custom transform, you need to subclass the `Transform` class and implement the
+:meth:`~torchrl.envs.Transform._apply_transform` method. Here is an example of a simple transform that adds 1 to the
+observation tensor:
+
+>>> class AddOneToObs(Transform):
+...     """A transform that adds 1 to the observation tensor."""
+...
+...     def __init__(self):
+...         super().__init__(in_keys=["observation"], out_keys=["observation"])
+...
+...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
+...         return obs + 1
+
+
+Tips for subclassing `Transform`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+There are various ways of subclassing a transform. The things to take into consideration are:
+
+- Is the transform identical for each tensor / item being transformed? If so, use
+  :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Does the transform need access to the input data of ``env.step`` as well as its output? If so, rewrite
+  :meth:`~torchrl.envs.Transform._step`.
+  Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`).
+- Is the transform to be used within a replay buffer? If so, overwrite :meth:`~torchrl.envs.Transform.forward`,
+  :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or
+  :meth:`~torchrl.envs.Transform._inv_apply_transform`.
+- Within a transform, you can access (and make calls to) the parent environment using
+  :attr:`~torchrl.envs.Transform.parent` (the base env plus all transforms up to this one) or
+  :meth:`~torchrl.envs.Transform.container` (the object that encapsulates the transform).
+- Don't forget to edit the specs if needed. Top level: :meth:`~torchrl.envs.Transform.transform_output_spec` and
+  :meth:`~torchrl.envs.Transform.transform_input_spec`.
+  Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`,
+  :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`,
+  :meth:`~torchrl.envs.Transform.transform_reward_spec` and
+  :meth:`~torchrl.envs.Transform.transform_done_spec`.
+
+For practical examples, see the methods listed above.
+
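The first tip in the list above (implement only `_apply_transform` when the same function applies to every entry) can be sketched without torchrl. Below, plain dicts stand in for tensordicts and the class names are illustrative, not the real API:

```python
# A torchrl-free sketch of the dispatch described above: a generic `_call`
# walks the in_keys/out_keys pairs and delegates the per-tensor work to
# `_apply_transform`, which is why simple transforms only implement the latter.
class Transform:
    def __init__(self, in_keys, out_keys):
        self.in_keys, self.out_keys = in_keys, out_keys

    def _apply_transform(self, value):
        raise NotImplementedError

    def _call(self, tensordict):
        # the same function is applied to every transformed entry
        for in_key, out_key in zip(self.in_keys, self.out_keys):
            tensordict[out_key] = self._apply_transform(tensordict[in_key])
        return tensordict

class AddOneToObs(Transform):
    def __init__(self):
        super().__init__(in_keys=["observation"], out_keys=["observation"])

    def _apply_transform(self, value):
        return value + 1

td = AddOneToObs()._call({"observation": 41})
print(td["observation"])  # 42
```

A transform whose logic depends on *which* entry is being processed, or that needs several entries at once, is the case where `_call` (or `_step`) should be overridden instead.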
+You can use a transform in an environment by passing it to the `TransformedEnv` constructor:
+
+>>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToObs())
+
+You can compose multiple transforms together using the `Compose` class:
+
+>>> transform = Compose(AddOneToObs(), RewardSum())
+>>> env = TransformedEnv(GymEnv("Pendulum-v1"), transform)
+
+Inverse Transforms
+^^^^^^^^^^^^^^^^^^
+
+Some transforms have an inverse transform that can be used to undo the transformation. For example, the AddOneToAction
+transform below uses the inverse path to add 1 to the action tensor before it is passed to the environment:
+
+>>> class AddOneToAction(Transform):
+...     """A transform that adds 1 to the action tensor."""
+...     def __init__(self):
+...         super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"])
+...     def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
+...         return action + 1
+
+Using a Transform with a Replay Buffer
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+You can use a transform with a replay buffer by passing it to the ReplayBuffer constructor:
 
 Cloning transforms
 ~~~~~~~~~~~~~~~~~~
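The replay-buffer usage that the new docs section describes can be sketched in plain Python. This is an illustrative stand-in, not the torchrl implementation (the real `torchrl.data.ReplayBuffer` does accept a `transform=` keyword): the key point is that data is stored untransformed and the transform's forward path runs at sampling time.

```python
# Minimal sketch (plain Python, illustrative names) of a replay buffer that
# applies a transform's forward path to sampled data.
import random

class AddOneToObs:
    """Stand-in transform: adds 1 to the "observation" entry."""

    def forward(self, item):
        item = dict(item)
        item["observation"] = item["observation"] + 1
        return item

class TinyReplayBuffer:
    def __init__(self, transform=None):
        self._storage = []
        self._transform = transform

    def add(self, item):
        # data is stored untransformed
        self._storage.append(item)

    def sample(self):
        # the transform runs on the way out, at sampling time
        item = random.choice(self._storage)
        return self._transform.forward(item) if self._transform else item

rb = TinyReplayBuffer(transform=AddOneToObs())
rb.add({"observation": 0})
print(rb.sample()["observation"])  # 1
```

Because the transform is applied lazily, changing or removing it later affects future samples without rewriting the stored data.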
@@ -1000,6 +1070,7 @@ to be able to create this other composition:
     TargetReturn
     TensorDictPrimer
     TimeMaxPool
+    Timer
     Tokenizer
     ToTensorImage
     TrajCounter

docs/source/reference/utils.rst (+3 −1)

@@ -1,4 +1,4 @@
-.. currentmodule:: torchrl._utils
+.. currentmodule:: torchrl
 
 torchrl._utils package
 ====================
@@ -11,3 +11,5 @@ Set of utility methods that are used internally by the library.
     :template: rl_template.rst
 
     implement_for
+    set_auto_unwrap_transformed_env
+    auto_unwrap_transformed_env

test/_utils_internal.py (+11 −2)

@@ -9,6 +9,7 @@
 import os
 
 import os.path
+import sys
 import time
 import unittest
 import warnings
@@ -42,12 +43,17 @@
     ToTensorImage,
     TransformedEnv,
 )
+from torchrl.modules import MLP
 from torchrl.objectives.value.advantages import _vmap_func
 
 # Specified for test_utils.py
 __version__ = "0.3"
 
-from torchrl.modules import MLP
+IS_WIN = sys.platform == "win32"
+if IS_WIN:
+    mp_ctx = "spawn"
+else:
+    mp_ctx = "fork"
 
 
 def CARTPOLE_VERSIONED():
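The `IS_WIN` gate added in this hunk exists because the "fork" multiprocessing start method is not available on Windows. A standard-library sketch of the same platform check:

```python
# Standard-library sketch of the platform gate above: "fork" does not exist on
# Windows, so multiprocessing there must fall back to "spawn".
import multiprocessing as mp
import sys

IS_WIN = sys.platform == "win32"
mp_ctx = "spawn" if IS_WIN else "fork"

# the chosen start method is one the current platform actually supports
print(mp_ctx in mp.get_all_start_methods())  # True
```

"spawn" starts a fresh interpreter per worker (slower, but safe everywhere), while "fork" copies the parent process (fast, POSIX-only), which is why the tests default to "fork" off Windows.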
@@ -265,6 +271,7 @@ def _make_envs(
     N,
     device="cpu",
     kwargs=None,
+    local_mp_ctx=mp_ctx,
 ):
     torch.manual_seed(0)
     if not transformed_in:
@@ -299,7 +306,9 @@ def create_env_fn():
         )
 
     env0 = create_env_fn()
-    env_parallel = ParallelEnv(N, create_env_fn, create_env_kwargs=kwargs)
+    env_parallel = ParallelEnv(
+        N, create_env_fn, create_env_kwargs=kwargs, mp_start_method=local_mp_ctx
+    )
     env_serial = SerialEnv(N, create_env_fn, create_env_kwargs=kwargs)
 
     for key in env0.observation_spec.keys(True, True):

test/test_collector.py (+14 −5)

@@ -3093,16 +3093,23 @@ def test_dynamic_sync_collector(self):
         assert isinstance(data, LazyStackedTensorDict)
         assert data.names[-1] == "time"
 
-    def test_dynamic_multisync_collector(self):
+    @pytest.mark.parametrize("policy_device", [None, *get_default_devices()])
+    def test_dynamic_multisync_collector(self, policy_device):
         env = EnvWithDynamicSpec
-        policy = RandomPolicy(env().action_spec)
+        spec = env().action_spec
+        if policy_device is not None:
+            spec = spec.to(policy_device)
+        policy = RandomPolicy(spec)
         collector = MultiSyncDataCollector(
             [env],
             policy,
             frames_per_batch=20,
             total_frames=100,
             use_buffers=False,
             cat_results="stack",
+            policy_device=policy_device,
+            env_device="cpu",
+            storing_device="cpu",
         )
         for data in collector:
             assert isinstance(data, LazyStackedTensorDict)
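The parametrization added above follows a common None-device convention: `None` means "leave the spec on its current device", any other value moves it. A plain-Python sketch (no torch; `FakeSpec` and `resolve_spec` are hypothetical stand-ins, not torchrl APIs):

```python
# Sketch of the None-device convention used in the parametrized test above.
class FakeSpec:
    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # mimics the out-of-place .to() of a spec/tensor
        return FakeSpec(device)

def resolve_spec(spec, policy_device=None):
    # mirrors: if policy_device is not None: spec = spec.to(policy_device)
    if policy_device is not None:
        spec = spec.to(policy_device)
    return spec

print(resolve_spec(FakeSpec()).device)            # cpu
print(resolve_spec(FakeSpec(), "cuda:0").device)  # cuda:0
```

Parametrizing over `[None, *devices]` thus covers both the "no device requested" path and every available accelerator in one test.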
@@ -3213,9 +3220,11 @@ def test_cudagraph_policy(self, collector_cls, cudagraph_policy):
 @pytest.mark.skipif(not _has_gym, reason="gym required for this test")
 class TestCollectorsNonTensor:
     class AddNontTensorData(Transform):
-        def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
-            tensordict["nt"] = f"a string! - {tensordict.get('step_count').item()}"
-            return tensordict
+        def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+            next_tensordict[
+                "nt"
+            ] = f"a string! - {next_tensordict.get('step_count').item()}"
+            return next_tensordict
 
         def _reset(
             self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
