Commit 7c034e3

[Feature] Transform for partial steps
ghstack-source-id: 587f91e33dfe1d59b73c4b2f2f1c21760ee79d2e
Pull Request resolved: #2777
1 parent f1c42e0 commit 7c034e3

10 files changed

+778
-112
lines changed

.github/workflows/test-linux.yml

-32
@@ -53,38 +53,6 @@ jobs:
         ## setup_env.sh
         bash .github/unittest/linux/scripts/run_all.sh
 
-  tests-cpu-oldget:
-    # Tests that TD_GET_DEFAULTS_TO_NONE=0 works fine as this will be the default for TD up to 0.7
-    strategy:
-      matrix:
-        python_version: ["3.12"]
-      fail-fast: false
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
-    with:
-      runner: linux.12xlarge
-      repository: pytorch/rl
-      docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
-      timeout: 90
-      script: |
-        if [[ "${{ github.ref }}" =~ release/* ]]; then
-          export RELEASE=1
-          export TORCH_VERSION=stable
-        else
-          export RELEASE=0
-          export TORCH_VERSION=nightly
-        fi
-        export TD_GET_DEFAULTS_TO_NONE=0
-
-        # Set env vars from matrix
-        export PYTHON_VERSION=${{ matrix.python_version }}
-        export CU_VERSION="cpu"
-
-        echo "PYTHON_VERSION: $PYTHON_VERSION"
-        echo "CU_VERSION: $CU_VERSION"
-
-        ## setup_env.sh
-        bash .github/unittest/linux/scripts/run_all.sh
-
   tests-gpu:
     strategy:
       matrix:

docs/source/reference/envs.rst

+76
@@ -163,6 +163,81 @@ provides more information on how to design a custom environment from scratch.
     GymLikeEnv
     EnvMetaData
 
+Partial steps and partial resets
+--------------------------------
+
+TorchRL allows environments to reset some but not all of their sub-environments, or to run a step in some but not all of them.
+If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed
+below.
+
+Batching environments and locking the batch
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. _ref_batch_locked:
+
+Before detailing what partial resets and partial steps do, we must distinguish between environments that have
+a batch size of their own (mostly stateful environments) and environments that are merely modules which, given an
+input of arbitrary size, batch the operations over all elements (mostly stateless environments).
+
+This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input
+tensordicts to have the same batch-size as the env's. Typical examples of these environments are
+:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size.
+Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`.
+
+Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the
+tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous
+input.
+
+Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with
+partial steps easily: they just pass the actions to the sub-environments that are required to be executed.
+
+In all other cases, TorchRL assumes that the environment handles the partial steps correctly.
+
+.. warning:: This means that custom environments may silently run the non-required steps as there is no way for TorchRL
+  to control what happens within the `_step` method!
+
+Partial Steps
+~~~~~~~~~~~~~
+
+.. _ref_partial_steps:
+
+Partial steps are controlled via the temporary key `"_step"` which points to a boolean mask of the
+size of the tensordict that holds it. The classes armed to deal with this are:
+
+- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
+  action to and only to the environments where `"_step"` is `True`;
+- Batch-unlocked environments;
+- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
+  method will first look for a `"_step"` entry and, if present, act accordingly.
+  If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
+  :class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further
+  transformation.
+
+When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
+content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
+if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
+all the non-stepped elements. Within batched environments, data collectors and rollout utilities, this issue does not
+arise because these classes handle the passing of data properly.
+
+Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
+as the environments with a `True` done state will need to be skipped during calls to `_step`.
+
+The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.
+
+Partial Resets
+~~~~~~~~~~~~~~
+
+.. _ref_partial_resets:
+
+Partial resets work much like partial steps, but with the `"_reset"` entry.
+
+The same restrictions that apply to partial steps apply to partial resets.
+
+Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `False`,
+as the environments with a `True` done state will need to be reset, but not others.
+
+See the following section for a deep dive into partial resets within batched and vectorized environments.
+
 Vectorized envs
 ---------------
 
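The merge strategy described in the documentation hunk above (step only the masked elements, then fill the rest from the previous input or with zeros) can be sketched without any TorchRL dependency. This is a toy illustration under stated assumptions, not the library's implementation: plain dicts of lists stand in for tensordicts, and `partial_step` and `toy_step` are hypothetical names.

```python
from typing import Callable, Dict, List

# Toy stand-in for a tensordict: key -> one value per batch element.
Batch = Dict[str, List[float]]


def partial_step(
    step_fn: Callable[[Batch], Batch],
    data: Batch,
    mask: List[bool],
) -> Batch:
    """Run ``step_fn`` only where ``mask`` is True and merge the result back."""
    idx = [i for i, run in enumerate(mask) if run]
    # 1. Select the elements that must be stepped.
    selected = {key: [values[i] for i in idx] for key, values in data.items()}
    # 2. Step the selected sub-batch only (nothing runs if the mask is all False).
    stepped = step_fn(selected) if idx else {}
    # 3. Merge: stepped values where the mask is True; elsewhere the previous
    #    content of the input if the key exists there, otherwise zeros.
    merged: Batch = {}
    for key, new_values in stepped.items():
        previous = data.get(key)
        out = list(previous) if previous is not None else [0.0] * len(mask)
        for pos, i in enumerate(idx):
            out[i] = new_values[pos]
        merged[key] = out
    return merged


def toy_step(batch: Batch) -> Batch:
    # Hypothetical stateless dynamics: add the action to the observation.
    obs = [o + a for o, a in zip(batch["observation"], batch["action"])]
    return {"observation": obs, "reward": [1.0] * len(obs)}


data = {"observation": [10.0, 20.0, 30.0, 40.0], "action": [1.0, 1.0, 1.0, 1.0]}
result = partial_step(toy_step, data, [False, True, False, True])
print(result)
```

In this sketch the `"observation"` entries of the skipped elements keep their previous values because the input carried them, whereas `"reward"` falls back to zero because the input had no such key.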

@@ -886,6 +961,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalSkip
     Crop
     DTypeCastTransform
     DeviceCastTransform
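The documentation added in this commit says that a transform can write a `"_step"` entry so that the wrapped environment's step is skipped. The idea can be illustrated for the unbatched case with a library-free sketch. It does not use or reproduce the `ConditionalSkip` API; `step_with_conditional_skip`, `base_step` and `should_skip` are hypothetical names chosen for illustration.

```python
from typing import Callable, Dict, List

Data = Dict[str, int]


def step_with_conditional_skip(
    base_step: Callable[[Data], Data],
    should_skip: Callable[[Data], bool],
    data: Data,
) -> Data:
    """Set a temporary "_step" flag from a predicate and honor it."""
    data = dict(data)  # do not mutate the caller's dict
    # "_step" is True when the step must be executed.
    data["_step"] = not should_skip(data)
    # The key is temporary: it is consumed before anything is returned.
    run = data.pop("_step")
    if not run:
        # Skipped: the base step is never called and previous content is kept.
        return data
    return base_step(data)


calls: List[int] = []


def base_step(data: Data) -> Data:
    # Hypothetical environment step: record the call and increment the observation.
    calls.append(data["observation"])
    return {**data, "observation": data["observation"] + 1}


def should_skip(data: Data) -> bool:
    # Hypothetical condition: stop stepping once the observation reaches 2.
    return data["observation"] >= 2


data = {"observation": 0}
for _ in range(5):
    data = step_with_conditional_skip(base_step, should_skip, data)
print(data, calls)
```

The `calls` list makes the skip observable: the base step runs only while the condition is false, and later iterations return the data unchanged.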

test/mocking_classes.py

+10-1
@@ -538,6 +538,8 @@ def _step(
 
 
 class ContinuousActionVecMockEnv(_MockEnv):
+    adapt_dtype: bool = True
+
     @classmethod
     def __new__(
         cls,
@@ -635,7 +637,14 @@ def _step(
         while done.shape != tensordict.shape:
             done = done.any(-1)
         done = reward = done.unsqueeze(-1)
-        tensordict.set("reward", reward.to(torch.get_default_dtype()))
+        tensordict.set(
+            "reward",
+            reward.to(
+                self.reward_spec.dtype
+                if self.adapt_dtype
+                else torch.get_default_dtype()
+            ).expand(self.reward_spec.shape),
+        )
         tensordict.set("done", done)
         tensordict.set("terminated", done)
         return tensordict

test/test_env.py

+48-35
@@ -235,6 +235,7 @@
 class TestEnvBase:
     def test_run_type_checks(self):
         env = ContinuousActionVecMockEnv()
+        env.adapt_dtype = False
         env._run_type_checks = False
         check_env_specs(env)
         env._run_type_checks = True
@@ -4112,17 +4113,21 @@ def test_parallel_partial_steps(
             use_buffers=use_buffers,
             device=device,
         )
-        td = penv.reset()
-        psteps = torch.zeros(4, dtype=torch.bool)
-        psteps[[1, 3]] = True
-        td.set("_step", psteps)
-
-        td.set("action", penv.full_action_spec[penv.action_key].one())
-        td = penv.step(td)
-        assert (td[0].get("next") == 0).all()
-        assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
-        assert (td[3].get("next") != 0).any()
+        try:
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.full_action_spec[penv.action_key].one())
+            td = penv.step(td)
+            assert_allclose_td(td[0].get("next"), td[0], intersection=True)
+            assert (td[1].get("next") != 0).any()
+            assert_allclose_td(td[2].get("next"), td[2], intersection=True)
+            assert (td[3].get("next") != 0).any()
+        finally:
+            penv.close()
+            del penv
 
     @pytest.mark.parametrize("use_buffers", [False, True])
     def test_parallel_partial_step_and_maybe_reset(
@@ -4135,17 +4140,21 @@ def test_parallel_partial_step_and_maybe_reset(
             use_buffers=use_buffers,
             device=device,
         )
-        td = penv.reset()
-        psteps = torch.zeros(4, dtype=torch.bool)
-        psteps[[1, 3]] = True
-        td.set("_step", psteps)
-
-        td.set("action", penv.full_action_spec[penv.action_key].one())
-        td, tdreset = penv.step_and_maybe_reset(td)
-        assert (td[0].get("next") == 0).all()
-        assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
-        assert (td[3].get("next") != 0).any()
+        try:
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool, device=td.get("done").device)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.full_action_spec[penv.action_key].one())
+            td, tdreset = penv.step_and_maybe_reset(td)
+            assert_allclose_td(td[0].get("next"), td[0], intersection=True)
+            assert (td[1].get("next") != 0).any()
+            assert_allclose_td(td[2].get("next"), td[2], intersection=True)
+            assert (td[3].get("next") != 0).any()
+        finally:
+            penv.close()
+            del penv
 
     @pytest.mark.parametrize("use_buffers", [False, True])
     def test_serial_partial_steps(self, use_buffers, device, env_device):
@@ -4156,17 +4165,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
             use_buffers=use_buffers,
             device=device,
         )
-        td = penv.reset()
-        psteps = torch.zeros(4, dtype=torch.bool)
-        psteps[[1, 3]] = True
-        td.set("_step", psteps)
-
-        td.set("action", penv.full_action_spec[penv.action_key].one())
-        td = penv.step(td)
-        assert (td[0].get("next") == 0).all()
-        assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
-        assert (td[3].get("next") != 0).any()
+        try:
+            td = penv.reset()
+            psteps = torch.zeros(4, dtype=torch.bool)
+            psteps[[1, 3]] = True
+            td.set("_step", psteps)
+
+            td.set("action", penv.full_action_spec[penv.action_key].one())
+            td = penv.step(td)
+            assert_allclose_td(td[0].get("next"), td[0], intersection=True)
+            assert (td[1].get("next") != 0).any()
+            assert_allclose_td(td[2].get("next"), td[2], intersection=True)
+            assert (td[3].get("next") != 0).any()
+        finally:
+            penv.close()
+            del penv
 
     @pytest.mark.parametrize("use_buffers", [False, True])
     def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
@@ -4184,9 +4197,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
 
         td.set("action", penv.full_action_spec[penv.action_key].one())
         td = penv.step(td)
-        assert (td[0].get("next") == 0).all()
+        assert_allclose_td(td[0].get("next"), td[0], intersection=True)
         assert (td[1].get("next") != 0).any()
-        assert (td[2].get("next") == 0).all()
+        assert_allclose_td(td[2].get("next"), td[2], intersection=True)
         assert (td[3].get("next") != 0).any()
 
 
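The updated assertions in these tests no longer expect zeros for the non-stepped elements; they expect the `"next"` content to match the previous content on the keys the two share. A library-free sketch of that comparison follows. It assumes that `intersection=True` restricts `assert_allclose_td` to keys present in both operands; `assert_close_on_shared_keys` is a hypothetical helper, not part of tensordict.

```python
import math
from typing import Dict, Set


def assert_close_on_shared_keys(
    actual: Dict[str, float],
    expected: Dict[str, float],
    rel_tol: float = 1e-6,
) -> Set[str]:
    """Compare only the keys present in both dicts and return those keys."""
    shared = actual.keys() & expected.keys()
    for key in sorted(shared):
        if not math.isclose(actual[key], expected[key], rel_tol=rel_tol):
            raise AssertionError(f"{key!r}: {actual[key]} != {expected[key]}")
    return set(shared)


# Root content of a skipped element, and the "next" content produced for it.
root = {"observation": 0.5, "action": 1.0, "done": 0.0}
next_skipped = {"observation": 0.5, "reward": 0.0, "done": 0.0}

# "action" and "reward" are ignored because each exists on one side only.
compared = assert_close_on_shared_keys(next_skipped, root)
print(sorted(compared))
```

Under the old zero-filling behavior the shared `"observation"` entry would have differed from the root value, so a check of this kind would have failed for skipped elements.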
