@@ -865,6 +865,8 @@ The inverse process is executed with the output tensordict, where the `in_keys`
865
865
866
866
Rename transform logic
867
867
868
+ .. note :: During a call to `inv`, the transforms are executed in reversed order (compared to the forward / step mode).
869
+
868
870
Transforming Tensors and Specs
869
871
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
870
872
@@ -900,6 +902,74 @@ tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand
900
902
environment. Instead, `"action_discrete" ` should be generated, and its continuous counterpart obtained from the
901
903
transform. Therefore, the user should see the `"action_discrete" ` entry being exposed, but not `"action" `.
902
904
905
+ Designing your own Transform
906
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
907
+
908
+ To create a basic, custom transform, you need to subclass the `Transform ` class and implement the
909
+ :meth: `~torchrl.envs._apply_transform ` method. Here's an example of a simple transform that adds 1 to the observation
910
+ tensor:
911
+
912
+ >>> class AddOneToObs (Transform ):
913
+ ... """ A transform that adds 1 to the observation tensor."""
914
+ ...
915
+ ... def __init__ (self ):
916
+ ... super ().__init__ (in_keys = [" observation" ], out_keys = [" observation" ])
917
+ ...
918
+ ... def _apply_transform (self , obs : torch.Tensor) -> torch.Tensor:
919
+ ... return obs + 1
920
+
921
+
922
+ Tips for subclassing `Transform `
923
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
924
+
925
+ There are various ways of subclassing a transform. The things to take into considerations are:
926
+
927
+ - Is the transform identical for each tensor / item being transformed? Use
928
+ :meth: `~torchrl.envs.Transform._apply_transform ` and :meth: `~torchrl.envs.Transform._inv_apply_transform `.
929
+ - The transform needs access to the input data to env.step as well as output? Rewrite
930
+ :meth: `~torchrl.envs.Transform._step `.
931
+ Otherwise, rewrite :meth: `~torchrl.envs.Transform._call ` (or :meth: `~torchrl.envs.Transform._inv_call `).
932
+ - Is the transform to be used within a replay buffer? Overwrite :meth: `~torchrl.envs.Transform.forward `,
933
+ :meth: `~torchrl.envs.Transform.inv `, :meth: `~torchrl.envs.Transform._apply_transform ` or
934
+ :meth: `~torchrl.envs.Transform._inv_apply_transform `.
935
+ - Within a transform, you can access (and make calls to) the parent environment using
936
+ :attr: `~torchrl.envs.Transform.parent ` (the base env + all transforms till this one) or
937
+ :meth: `~torchrl.envs.Transform.container ` (The object that encapsulates the transform).
938
+ - Don't forget to edits the specs if needed: top level: :meth: `~torchrl.envs.Transform.transform_output_spec `,
939
+ :meth: `~torchrl.envs.Transform.transform_input_spec `.
940
+ Leaf level: :meth: `~torchrl.envs.Transform.transform_observation_spec `,
941
+ :meth: `~torchrl.envs.Transform.transform_action_spec `, :meth: `~torchrl.envs.Transform.transform_state_spec `,
942
+ :meth: `~torchrl.envs.Transform.transform_reward_spec ` and
943
+ :meth: `~torchrl.envs.Transform.transform_reward_spec `.
944
+
945
+ For practical examples, see the methods listed above.
946
+
947
+ You can use a transform in an environment by passing it to the TransformedEnv constructor:
948
+
949
+ >>> env = TransformedEnv(GymEnv(" Pendulum-v1" ), AddOneToObs())
950
+
951
+ You can compose multiple transforms together using the Compose class:
952
+
953
+ >>> transform = Compose(AddOneToObs(), RewardSum())
954
+ >>> env = TransformedEnv(GymEnv(" Pendulum-v1" ), transform)
955
+
956
+ Inverse Transforms
957
+ ^^^^^^^^^^^^^^^^^^
958
+
959
+ Some transforms have an inverse transform that can be used to undo the transformation. For example, the AddOneToAction
960
+ transform has an inverse transform that subtracts 1 from the action tensor:
961
+
962
+ >>> class AddOneToAction (Transform ):
963
+ ... """ A transform that adds 1 to the action tensor."""
964
+ ... def __init__ (self ):
965
+ ... super ().__init__ (in_keys = [], out_keys = [], in_keys_inv = [" action" ], out_keys_inv = [" action" ])
966
+ ... def _inv_apply_transform (self , action : torch.Tensor) -> torch.Tensor:
967
+ ... return action + 1
968
+
969
+ Using a Transform with a Replay Buffer
970
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
971
+
972
+ You can use a transform with a replay buffer by passing it to the ReplayBuffer constructor:
903
973
904
974
Cloning transforms
905
975
~~~~~~~~~~~~~~~~~~
@@ -1000,6 +1070,7 @@ to be able to create this other composition:
1000
1070
TargetReturn
1001
1071
TensorDictPrimer
1002
1072
TimeMaxPool
1073
+ Timer
1003
1074
Tokenizer
1004
1075
ToTensorImage
1005
1076
TrajCounter
0 commit comments