Skip to content

Commit 093a159

Browse files
committed
[Feature] UnaryTransform for input entries
ghstack-source-id: bb0ea97f47bdad6ba5e73692969fece4e2efbfb4 Pull Request resolved: #2700
1 parent 2c19fcc commit 093a159

File tree

6 files changed

+821
-143
lines changed

6 files changed

+821
-143
lines changed

Diff for: docs/source/reference/envs.rst

+63-12
Original file line numberDiff line numberDiff line change
@@ -731,29 +731,80 @@ pixels or states etc).
731731
Forward and inverse transforms
732732
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
733733

734-
Transforms also have an ``inv`` method that is called before
735-
the action is applied in reverse order over the composed transform chain:
736-
this allows to apply transforms to data in the environment before the action is taken
737-
in the environment. The keys to be included in this inverse transform are passed through the
738-
``"in_keys_inv"`` keyword argument:
734+
Transforms also have an :meth:`~torchrl.envs.Transform.inv` method that is called before the action is applied in reverse
735+
order over the composed transform chain. This allows applying transforms to data in the environment before the action is
736+
taken in the environment. The keys to be included in this inverse transform are passed through the `"in_keys_inv"`
737+
keyword argument, and the out-keys default to these values in most cases:
739738

740739
.. code-block::
741740
:caption: Inverse transform
742741
743742
>>> env.append_transform(DoubleToFloat(in_keys_inv=["action"])) # will map the action from float32 to float64 before calling the base_env.step
744743
745-
The way ``in_keys`` relates to ``in_keys_inv`` can be understood by considering the base environment as the "inner" part
746-
of the transform. In constrast, the user inputs and outputs to and from the transform are to be considered as the
747-
outside world. The following figure shows what this means in practice for the :class:`~torchrl.envs.RenameTransform`
748-
class: the input ``TensorDict`` of the ``step`` function must have the ``out_keys_inv`` listed in its entries as they
749-
are part of the outside world. The transform changes these names to make them match the names of the inner, base
750-
environment using the ``in_keys_inv``. The inverse process is executed with the output tensordict, where the ``in_keys``
751-
are mapped to the corresponding ``out_keys``.
744+
The following paragraphs detail how one can think about what is to be considered `in_` or `out_` features.
745+
746+
Understanding Transform Keys
747+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
748+
749+
In transforms, `in_keys` and `out_keys` define the interaction between the base environment and the outside world
750+
(e.g., your policy):
751+
752+
- `in_keys` refers to the base environment's perspective (inner = `base_env` of the
753+
:class:`~torchrl.envs.TransformedEnv`).
754+
- `out_keys` refers to the outside world (outer = `policy`, `agent`, etc.).
755+
756+
For example, with `in_keys=["obs"]` and `out_keys=["obs_standardized"]`, the policy will "see" a standardized
757+
observation, while the base environment outputs a regular observation.
758+
759+
Similarly, for inverse keys:
760+
761+
- `in_keys_inv` refers to entries as seen by the base environment.
762+
- `out_keys_inv` refers to entries as seen or produced by the policy.
763+
764+
The following figure illustrates this concept for the :class:`~torchrl.envs.RenameTransform` class: the input
765+
`TensorDict` of the `step` function must include the `out_keys_inv` as they are part of the outside world. The
766+
transform changes these names to match the names of the inner, base environment using the `in_keys_inv`.
767+
The inverse process is executed with the output tensordict, where the `in_keys` are mapped to the corresponding
768+
`out_keys`.
752769

753770
.. figure:: /_static/img/rename_transform.png
754771

755772
Rename transform logic
756773

774+
Transforming Tensors and Specs
775+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
776+
777+
When transforming actual tensors (coming from the policy), the process is schematically represented as:
778+
779+
>>> for t in reversed(self.transform):
780+
... td = t.inv(td)
781+
782+
This starts with the outermost transform to the innermost transform, ensuring the action value exposed to the policy
783+
is properly transformed.
784+
785+
For transforming the action spec, the process should go from innermost to outermost (similar to observation specs):
786+
787+
>>> def transform_action_spec(self, action_spec):
788+
... for t in self.transform:
789+
... action_spec = t.transform_action_spec(action_spec)
790+
... return action_spec
791+
792+
A pseudocode for a single transform_action_spec could be:
793+
794+
>>> def transform_action_spec(self, action_spec):
795+
... return spec_from_random_values(self._apply_transform(action_spec.rand()))
796+
797+
This approach ensures that the "outside" spec is inferred from the "inside" spec. Note that we did not call
798+
`_inv_apply_transform` but `_apply_transform` on purpose!
799+
800+
Exposing Specs to the Outside World
801+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
802+
803+
`TransformedEnv` will expose the specs corresponding to the `out_keys_inv` for actions and states.
804+
For example, with :class:`~torchrl.envs.ActionDiscretizer`, the environment's action (e.g., `"action"`) is a float-valued
805+
tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand_action` with the transformed
806+
environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
807+
transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.
757808

758809

759810
Cloning transforms

Diff for: test/mocking_classes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1070,17 +1070,20 @@ def _step(
10701070

10711071
class CountingEnvWithString(CountingEnv):
10721072
def __init__(self, *args, **kwargs):
1073+
self.max_size = kwargs.pop("max_size", 30)
1074+
self.min_size = kwargs.pop("min_size", 4)
10731075
super().__init__(*args, **kwargs)
10741076
self.observation_spec.set(
10751077
"string",
10761078
NonTensor(
10771079
shape=self.batch_size,
10781080
device=self.device,
1081+
example_data=self.get_random_string(),
10791082
),
10801083
)
10811084

10821085
def get_random_string(self):
1083-
size = random.randint(4, 30)
1086+
size = random.randint(self.min_size, self.max_size)
10841087
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
10851088

10861089
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

0 commit comments

Comments
 (0)