Commit a721d28

[Feature] Transform for partial steps
ghstack-source-id: 8110f09
Pull Request resolved: #2777
1 parent f1c42e0 commit a721d28

File tree

10 files changed: +762 -112 lines changed

.github/workflows/test-linux.yml

-32
@@ -53,38 +53,6 @@ jobs:
         ## setup_env.sh
         bash .github/unittest/linux/scripts/run_all.sh
 
-  tests-cpu-oldget:
-    # Tests that TD_GET_DEFAULTS_TO_NONE=0 works fine as this will be the default for TD up to 0.7
-    strategy:
-      matrix:
-        python_version: ["3.12"]
-      fail-fast: false
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
-    with:
-      runner: linux.12xlarge
-      repository: pytorch/rl
-      docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
-      timeout: 90
-      script: |
-        if [[ "${{ github.ref }}" =~ release/* ]]; then
-          export RELEASE=1
-          export TORCH_VERSION=stable
-        else
-          export RELEASE=0
-          export TORCH_VERSION=nightly
-        fi
-        export TD_GET_DEFAULTS_TO_NONE=0
-
-        # Set env vars from matrix
-        export PYTHON_VERSION=${{ matrix.python_version }}
-        export CU_VERSION="cpu"
-
-        echo "PYTHON_VERSION: $PYTHON_VERSION"
-        echo "CU_VERSION: $CU_VERSION"
-
-        ## setup_env.sh
-        bash .github/unittest/linux/scripts/run_all.sh
-
   tests-gpu:
     strategy:
       matrix:

docs/source/reference/envs.rst

+76
@@ -163,6 +163,81 @@ provides more information on how to design a custom environment from scratch.
     GymLikeEnv
     EnvMetaData
 
+Partial steps and partial resets
+--------------------------------
+
+TorchRL allows a batch of environments to reset some but not all of its environments, or to run a step in some but not all of them.
+If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed
+below.
+
+Batching environments and locking the batch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. _ref_batch_locked:
+
+Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has
+a batch size of its own (mostly stateful environments) or when the environment is just a mere module that, given an
+input of arbitrary size, batches the operations over all elements (mostly stateless environments).
+
+This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input
+tensordicts to have the same batch-size as the env's. Typical examples of these environments are
+:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size.
+Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`.
+
+Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the
+tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous
+input.
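As an illustration of the mask / step / merge recipe described in the paragraph above, here is a minimal pure-Python sketch. It does not use the TorchRL API: `partial_step`, `toy_step` and the dict-of-lists batch are hypothetical stand-ins for a tensordict and for the `step` of a batch-unlocked environment, and the actual implementation may differ.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List[float]]  # hypothetical stand-in for a tensordict


def partial_step(
    batch: Batch, step_mask: List[bool], step_fn: Callable[[Batch], Batch]
) -> Batch:
    """Run step_fn only on the rows selected by step_mask, then merge back."""
    selected = [row for row, flag in enumerate(step_mask) if flag]
    # 1. Mask: keep only the rows that must be executed.
    sub_batch = {key: [values[row] for row in selected] for key, values in batch.items()}
    # 2. Step: a batch-unlocked env accepts this smaller batch as-is.
    sub_out = step_fn(sub_batch) if selected else {}
    # 3. Merge: start from the previous input and overwrite the stepped rows.
    merged = {key: list(values) for key, values in batch.items()}
    for key, values in sub_out.items():
        column = merged.setdefault(key, [0.0] * len(step_mask))
        for row, value in zip(selected, values):
            column[row] = value
    return merged


def toy_step(batch: Batch) -> Batch:
    # Hypothetical dynamics: the next observation is observation + action.
    return {"observation": [o + a for o, a in zip(batch["observation"], batch["action"])]}


batch = {"observation": [0.0, 1.0, 2.0, 3.0], "action": [1.0, 1.0, 1.0, 1.0]}
out = partial_step(batch, [False, True, False, True], toy_step)
print(out["observation"])  # rows 1 and 3 are updated, rows 0 and 2 keep their input
```

Because the environment only ever sees the selected rows, the skipped rows cannot be stepped by accident, which is why this case is the easy one.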
+
+Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with
+partial steps easily: they only pass the actions to the sub-environments that need to be executed.
+
+In all other cases, TorchRL assumes that the environment handles the partial steps correctly.
+
+.. warning:: This means that custom environments may silently run the non-required steps as there is no way for torchrl
+  to control what happens within the `_step` method!
+
+Partial Steps
+~~~~~~~~~~~~~
+
+.. _ref_partial_steps:
+
+Partial steps are controlled via the temporary key `"_step"` which points to a boolean mask of the
+size of the tensordict that holds it. The classes armed to deal with this are:
+
+- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
+  action to and only to the environments where `"_step"` is `True`;
+- Batch-unlocked environments;
+- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
+  method will first look for a `"_step"` entry and, if present, act accordingly.
+  If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
+  :class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further
+  transformation.
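The dispatch rule for batched environments listed above can be pictured with a small pure-Python sketch. `CountingEnv` and `dispatch_step` are hypothetical names introduced only for illustration; they are not part of TorchRL, and `ParallelEnv` / `SerialEnv` do considerably more (buffers, processes, devices).

```python
from typing import List


class CountingEnv:
    """Hypothetical stateful sub-environment that counts how often it is stepped."""

    def __init__(self) -> None:
        self.num_steps = 0
        self.position = 0.0

    def step(self, action: float) -> float:
        self.num_steps += 1
        self.position += action
        return self.position


def dispatch_step(
    envs: List[CountingEnv], actions: List[float], step_mask: List[bool]
) -> List[float]:
    # Only sub-environments whose mask entry is True receive their action;
    # the others are left untouched and report their current position.
    return [
        env.step(action) if flag else env.position
        for env, action, flag in zip(envs, actions, step_mask)
    ]


envs = [CountingEnv() for _ in range(4)]
positions = dispatch_step(envs, [1.0] * 4, [False, True, False, True])
step_counts = [env.num_steps for env in envs]
print(positions, step_counts)
```

The step counters make the point of the warning above concrete: in this sketch the skipped sub-environments are provably never called, whereas a custom batch-locked environment has to enforce that itself.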
+
+When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
+content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
+if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
+all the non-stepped elements. Within batched environments, data collectors and rollouts utils, this is an issue that
+is not observed because these classes handle the passing of data properly.
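The fill rule in the paragraph above (step output where stepped, previous input where available, zeros otherwise) can be sketched in plain Python. `fill_non_stepped` is a hypothetical helper, not a TorchRL function, and lists with `None` placeholders stand in for tensors with missing rows.

```python
from typing import Dict, List, Optional


def fill_non_stepped(
    step_out: Dict[str, List[Optional[float]]],
    previous: Dict[str, List[float]],
    step_mask: List[bool],
) -> Dict[str, List[float]]:
    """Complete a partial step output for the rows that were skipped."""
    filled = {}
    for key, new_values in step_out.items():
        old_values = previous.get(key)
        column = []
        for row, stepped in enumerate(step_mask):
            if stepped:
                column.append(new_values[row])  # use the step output
            elif old_values is not None:
                column.append(old_values[row])  # reuse the previous input
            else:
                column.append(0.0)  # nothing to copy from: zero-fill
        filled[key] = column
    return filled


step_mask = [False, True, False, True]
# None marks rows the environment did not compute.
step_out = {"observation": [None, 5.0, None, 7.0], "reward": [None, 1.0, None, 1.0]}
# The input holds the previous observation but no previous reward.
previous = {"observation": [0.5, 1.5, 2.5, 3.5]}
next_data = fill_non_stepped(step_out, previous, step_mask)
print(next_data)
```

This is also what the updated tests in this commit check: for skipped rows, `"next"` is compared against the input on the keys both share, instead of being expected to be all zeros.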
+
+Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
+as the environments with a `True` done state will need to be skipped during calls to `_step`.
+
+The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.
+
+Partial Resets
+~~~~~~~~~~~~~~
+
+.. _ref_partial_resets:
+
+Partial resets work pretty much like partial steps, but with the `"_reset"` entry.
+
+The same restrictions of partial steps apply to partial resets.
+
+Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `False`,
+as the environments with a `True` done state will need to be reset, but not others.
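To make the role of the `"_reset"` mask concrete, here is a toy rollout that keeps running after some environments finish and resets only those. It is a pure-Python sketch under stated assumptions: `partial_reset` and `toy_rollout` are hypothetical helpers, each "environment" is just an integer counter, and none of this is the TorchRL implementation.

```python
from typing import List


def partial_reset(
    state: List[int], reset_mask: List[bool], initial_value: int = 0
) -> List[int]:
    # Rows flagged in the mask go back to the initial value, the rest are kept.
    return [initial_value if flag else value for value, flag in zip(state, reset_mask)]


def toy_rollout(num_steps: int, horizon: int) -> List[List[int]]:
    state = [0, 1, 2]  # three counters with staggered starting points
    history = []
    for _ in range(num_steps):
        state = [value + 1 for value in state]  # step every environment
        history.append(list(state))
        done = [value >= horizon for value in state]
        # The done flags play the role of the "_reset" mask.
        state = partial_reset(state, done)
    return history


history = toy_rollout(num_steps=3, horizon=3)
print(history)
```

Each counter reaches the horizon at a different step, so exactly one row is reset per iteration while the other two carry on undisturbed.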
+
+See the following paragraph for a deep dive into partial resets within batched and vectorized environments.
+
 Vectorized envs
 ---------------
 

@@ -886,6 +961,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalSkip
     Crop
     DTypeCastTransform
     DeviceCastTransform

test/mocking_classes.py

+3-1
@@ -635,7 +635,9 @@ def _step(
         while done.shape != tensordict.shape:
             done = done.any(-1)
         done = reward = done.unsqueeze(-1)
-        tensordict.set("reward", reward.to(torch.get_default_dtype()))
+        tensordict.set(
+            "reward", reward.to(self.reward_spec.dtype).expand(self.reward_spec.shape)
+        )
         tensordict.set("done", done)
         tensordict.set("terminated", done)
         return tensordict

test/test_env.py

+47-35
@@ -4112,17 +4112,21 @@ def test_parallel_partial_steps(
             use_buffers=use_buffers,
             device=device,
         )
-        td = penv.reset()
-        psteps = torch.zeros(4, dtype=torch.bool)
-        psteps[[1, 3]] = True
-        td.set("_step", psteps)
-
-        td.set("action", penv.full_action_spec[penv.action_key].one())
-        td = penv.step(td)
-        assert (td[0].get("next") == 0).all()
-        assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
-        assert (td[3].get("next") != 0).any()
+        try:
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.full_action_spec[penv.action_key].one())
+            td = penv.step(td)
+            assert_allclose_td(td[0].get("next"), td[0], intersection=True)
+            assert (td[1].get("next") != 0).any()
+            assert_allclose_td(td[2].get("next"), td[2], intersection=True)
+            assert (td[3].get("next") != 0).any()
+        finally:
+            penv.close()
+            del penv
 
     @pytest.mark.parametrize("use_buffers", [False, True])
     def test_parallel_partial_step_and_maybe_reset(
@@ -4135,17 +4139,21 @@ def test_parallel_partial_step_and_maybe_reset(
             use_buffers=use_buffers,
             device=device,
         )
-        td = penv.reset()
-        psteps = torch.zeros(4, dtype=torch.bool)
-        psteps[[1, 3]] = True
-        td.set("_step", psteps)
-
-        td.set("action", penv.full_action_spec[penv.action_key].one())
-        td, tdreset = penv.step_and_maybe_reset(td)
-        assert (td[0].get("next") == 0).all()
-        assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
-        assert (td[3].get("next") != 0).any()
+        try:
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.full_action_spec[penv.action_key].one())
+            td, tdreset = penv.step_and_maybe_reset(td)
+            assert_allclose_td(td[0].get("next"), td[0], intersection=True)
+            assert (td[1].get("next") != 0).any()
+            assert_allclose_td(td[2].get("next"), td[2], intersection=True)
+            assert (td[3].get("next") != 0).any()
+        finally:
+            penv.close()
+            del penv
 
     @pytest.mark.parametrize("use_buffers", [False, True])
     def test_serial_partial_steps(self, use_buffers, device, env_device):
@@ -4156,17 +4164,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
             use_buffers=use_buffers,
             device=device,
         )
-        td = penv.reset()
-        psteps = torch.zeros(4, dtype=torch.bool)
-        psteps[[1, 3]] = True
-        td.set("_step", psteps)
-
-        td.set("action", penv.full_action_spec[penv.action_key].one())
-        td = penv.step(td)
-        assert (td[0].get("next") == 0).all()
-        assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
-        assert (td[3].get("next") != 0).any()
+        try:
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.full_action_spec[penv.action_key].one())
+            td = penv.step(td)
+            assert_allclose_td(td[0].get("next"), td[0], intersection=True)
+            assert (td[1].get("next") != 0).any()
+            assert_allclose_td(td[2].get("next"), td[2], intersection=True)
+            assert (td[3].get("next") != 0).any()
+        finally:
+            penv.close()
+            del penv
 
     @pytest.mark.parametrize("use_buffers", [False, True])
     def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
@@ -4184,9 +4196,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
 
         td.set("action", penv.full_action_spec[penv.action_key].one())
         td = penv.step(td)
-        assert (td[0].get("next") == 0).all()
+        assert_allclose_td(td[0].get("next"), td[0], intersection=True)
         assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
+        assert_allclose_td(td[2].get("next"), td[2], intersection=True)
         assert (td[3].get("next") != 0).any()
 
 