Commit 6508658

Update (base update)
[ghstack-poisoned]
1 parent b27ee6d commit 6508658

12 files changed

+1188
-151
lines changed

.github/workflows/test-linux.yml

Lines changed: 0 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -53,38 +53,6 @@ jobs:
5353
## setup_env.sh
5454
bash .github/unittest/linux/scripts/run_all.sh
5555
56-
tests-cpu-oldget:
57-
# Tests that TD_GET_DEFAULTS_TO_NONE=0 works fine as this will be the default for TD up to 0.7
58-
strategy:
59-
matrix:
60-
python_version: ["3.12"]
61-
fail-fast: false
62-
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
63-
with:
64-
runner: linux.12xlarge
65-
repository: pytorch/rl
66-
docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
67-
timeout: 90
68-
script: |
69-
if [[ "${{ github.ref }}" =~ release/* ]]; then
70-
export RELEASE=1
71-
export TORCH_VERSION=stable
72-
else
73-
export RELEASE=0
74-
export TORCH_VERSION=nightly
75-
fi
76-
export TD_GET_DEFAULTS_TO_NONE=0
77-
78-
# Set env vars from matrix
79-
export PYTHON_VERSION=${{ matrix.python_version }}
80-
export CU_VERSION="cpu"
81-
82-
echo "PYTHON_VERSION: $PYTHON_VERSION"
83-
echo "CU_VERSION: $CU_VERSION"
84-
85-
## setup_env.sh
86-
bash .github/unittest/linux/scripts/run_all.sh
87-
8856
tests-gpu:
8957
strategy:
9058
matrix:

docs/source/reference/envs.rst

Lines changed: 78 additions & 0 deletions
@@ -163,6 +163,81 @@ provides more information on how to design a custom environment from scratch.
163163
GymLikeEnv
164164
EnvMetaData
165165

166+
Partial steps and partial resets
167+
--------------------------------
168+
169+
TorchRL allows environments to reset some but not all of their sub-environments, or to run a step in some but not all of them.
170+
If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed
171+
below.
172+
173+
Batching environments and locking the batch
174+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
175+
176+
.. _ref_batch_locked:
177+
178+
Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has
179+
a batch size of its own (mostly stateful environments) from cases where the environment is merely a module that, given an
180+
input of arbitrary size, batches the operations over all elements (mostly stateless environments).
181+
182+
This is controlled via the :attr:`~torchrl.envs.EnvBase.batch_locked` attribute: a batch-locked environment requires all input
183+
tensordicts to have the same batch-size as the env's. Typical examples of these environments are
184+
:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size.
185+
Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`.
186+
187+
Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the
188+
tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous
189+
input.
190+
191+
Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with
192+
partial steps easily: they just pass the actions to the sub-environments that are required to be executed.
193+
194+
In all other cases, TorchRL assumes that the environment handles the partial steps correctly.
195+
196+
.. warning:: This means that custom environments may silently run the non-required steps as there is no way for TorchRL
197+
to control what happens within the `_step` method!
198+
199+
Partial Steps
200+
~~~~~~~~~~~~~
201+
202+
.. _ref_partial_steps:
203+
204+
Partial steps are controlled via the temporary key `"_step"` which points to a boolean mask of the
205+
size of the tensordict that holds it. The classes equipped to deal with this are:
206+
207+
- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
208+
action only to the environments where `"_step"` is `True`;
209+
- Batch-unlocked environments;
210+
- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
211+
method will first look for a `"_step"` entry and, if present, act accordingly.
212+
If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
213+
:class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further
214+
transformation.
215+
216+
When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
217+
content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
218+
if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
219+
all the non-stepped elements. Within batched environments, data collectors and rollout utilities, this is an issue that
220+
is not observed because these classes handle the passing of data properly.
221+
222+
Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
223+
as the environments with a `True` done state will need to be skipped during calls to `_step`.
224+
225+
The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.
226+
227+
Partial Resets
228+
~~~~~~~~~~~~~~
229+
230+
.. _ref_partial_resets:
231+
232+
Partial resets work much like partial steps, but use the `"_reset"` entry.
233+
234+
The same restrictions that apply to partial steps also apply to partial resets.
235+
236+
Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `False`,
237+
as the environments with a `True` done state will need to be reset, but not others.
238+
239+
See the following paragraph for a deep dive into partial resets within batched and vectorized environments.
240+
166241
Vectorized envs
167242
---------------
168243

@@ -212,6 +287,7 @@ component (sub-environments or agents) should be reset.
212287
This allows resetting some but not all of the components.
213288

214289
The ``"_reset"`` key has two distinct functionalities:
290+
215291
1. During a call to :meth:`~.EnvBase._reset`, the ``"_reset"`` key may or may
216292
not be present in the input tensordict. TorchRL's convention is that the
217293
absence of the ``"_reset"`` key at a given ``"done"`` level indicates
@@ -885,6 +961,7 @@ to be able to create this other composition:
885961
CenterCrop
886962
ClipTransform
887963
Compose
964+
ConditionalSkip
888965
Crop
889966
DTypeCastTransform
890967
DeviceCastTransform
@@ -900,6 +977,7 @@ to be able to create this other composition:
900977
InitTracker
901978
KLRewardTransform
902979
LineariseReward
980+
MultiAction
903981
NoopResetEnv
904982
ObservationNorm
905983
ObservationTransform
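
To make the fill rule from the added "Partial Steps" documentation concrete, here is a minimal sketch in plain Python. It is not TorchRL code: `partial_step`, `count_step` and the dict-of-lists batch are hypothetical stand-ins for `EnvBase.step`, an environment's `_step` method and a tensordict. It only illustrates the behavior described above, assuming a flat batch: step the masked elements, then fill the remaining ones from the input when the key exists there, or with zeros otherwise.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List[float]]


def partial_step(
    batch: Batch, mask: List[bool], step_fn: Callable[[Batch], Batch]
) -> Batch:
    """Run step_fn on the masked elements only and merge with the input."""
    idx = [i for i, flag in enumerate(mask) if flag]
    # Keep only the elements flagged for execution (the role of "_step").
    sub = {key: [values[i] for i in idx] for key, values in batch.items()}
    out = step_fn(sub)
    merged: Batch = {}
    for key, values in out.items():
        # Previous content if the input holds this key, zeros otherwise.
        base = list(batch[key]) if key in batch else [0.0] * len(mask)
        for j, i in enumerate(idx):
            base[i] = values[j]
        merged[key] = base
    return merged


def count_step(sub: Batch) -> Batch:
    """Hypothetical step function: add the action to the count."""
    return {
        "count": [c + a for c, a in zip(sub["count"], sub["action"])],
        "reward": [1.0 for _ in sub["count"]],
    }


batch = {"count": [0.0, 0.0, 5.0, 5.0], "action": [1.0, 1.0, 1.0, 1.0]}
result = partial_step(batch, [False, True, False, True], count_step)
```

In this sketch the non-stepped elements keep their previous `count`, which is the behavior the updated tests check by comparing `td[0].get("next")` with `td[0]`. The `reward` key is absent from the input, so its non-stepped elements fall back to 0.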

test/mocking_classes.py

Lines changed: 94 additions & 15 deletions
@@ -358,13 +358,11 @@ def _step(self, tensordict):
358358
leading_batch_size = tensordict.shape if tensordict is not None else []
359359
self.counter += 1
360360
# We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
361-
n = (
362-
torch.full(
363-
[*leading_batch_size, *self.observation_spec["observation"].shape],
364-
self.counter,
365-
)
366-
.to(self.device)
367-
.to(torch.get_default_dtype())
361+
n = torch.full(
362+
[*leading_batch_size, *self.observation_spec["observation"].shape],
363+
self.counter,
364+
device=self.device,
365+
dtype=torch.get_default_dtype(),
368366
)
369367
done = self.counter >= self.max_val
370368
done = torch.full(
@@ -391,13 +389,11 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
391389
else:
392390
leading_batch_size = tensordict.shape if tensordict is not None else []
393391

394-
n = (
395-
torch.full(
396-
[*leading_batch_size, *self.observation_spec["observation"].shape],
397-
self.counter,
398-
)
399-
.to(self.device)
400-
.to(torch.get_default_dtype())
392+
n = torch.full(
393+
[*leading_batch_size, *self.observation_spec["observation"].shape],
394+
self.counter,
395+
device=self.device,
396+
dtype=torch.get_default_dtype(),
401397
)
402398
done = self.counter >= self.max_val
403399
done = torch.full(
@@ -417,7 +413,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
417413

418414

419415
class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
420-
"""Mocks an env whose batch_size does not define the size of the output tensordict.
416+
"""Mocks an env which batch_size does not define the size of the output tensordict.
421417
422418
The size of the output tensordict is defined by the input tensordict itself.
423419
@@ -433,6 +429,89 @@ def __new__(cls, *args, **kwargs):
433429
return super().__new__(cls, *args, _batch_locked=False, **kwargs)
434430

435431

432+
class StateLessCountingEnv(EnvBase):
433+
def __init__(self):
434+
self.observation_spec = Composite(
435+
count=Unbounded((1,), dtype=torch.int32),
436+
max_count=Unbounded((1,), dtype=torch.int32),
437+
)
438+
self.full_action_spec = Composite(
439+
action=Unbounded((1,), dtype=torch.int32),
440+
)
441+
self.full_done_spec = Composite(
442+
done=Unbounded((1,), dtype=torch.bool),
443+
terminated=Unbounded((1,), dtype=torch.bool),
444+
truncated=Unbounded((1,), dtype=torch.bool),
445+
)
446+
self.reward_spec = Composite(reward=Unbounded((1,), dtype=torch.float))
447+
super().__init__()
448+
self._batch_locked = False
449+
450+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
451+
452+
max_count = None
453+
count = None
454+
if tensordict is not None:
455+
max_count = tensordict.get("max_count")
456+
count = tensordict.get("count")
457+
tensordict = TensorDict(
458+
batch_size=tensordict.batch_size, device=tensordict.device
459+
)
460+
shape = tensordict.batch_size
461+
else:
462+
shape = ()
463+
tensordict = TensorDict(device=self.device)
464+
tensordict.update(
465+
TensorDict(
466+
count=torch.zeros(
467+
(
468+
*shape,
469+
1,
470+
),
471+
dtype=torch.int32,
472+
)
473+
if count is None
474+
else count,
475+
max_count=torch.randint(
476+
10,
477+
20,
478+
(
479+
*shape,
480+
1,
481+
),
482+
dtype=torch.int32,
483+
)
484+
if max_count is None
485+
else max_count,
486+
**self.done_spec.zero(shape),
487+
**self.full_reward_spec.zero(shape),
488+
)
489+
)
490+
return tensordict
491+
492+
def _step(
493+
self,
494+
tensordict: TensorDictBase,
495+
) -> TensorDictBase:
496+
action = tensordict["action"]
497+
count = tensordict["count"] + action
498+
terminated = done = count >= tensordict["max_count"]
499+
truncated = torch.zeros_like(done)
500+
return TensorDict(
501+
count=count,
502+
max_count=tensordict["max_count"],
503+
done=done,
504+
terminated=terminated,
505+
truncated=truncated,
506+
reward=self.reward_spec.zero(tensordict.shape),
507+
batch_size=tensordict.batch_size,
508+
device=tensordict.device,
509+
)
510+
511+
def _set_seed(self, seed: Optional[int]):
512+
...
513+
514+
436515
class DiscreteActionVecMockEnv(_MockEnv):
437516
@classmethod
438517
def __new__(
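
The `StateLessCountingEnv._step` added in this file reads everything it needs from its input and returns a fresh output, which is what lets it run on a batch of any size. The following is a rough plain-Python sketch of that idea, assuming lists in place of tensors; `stateless_count_step` is a hypothetical name and not part of the commit.

```python
from typing import Dict, List


def stateless_count_step(state: Dict[str, list]) -> Dict[str, list]:
    """Compute the next state from the input alone, with no attributes on self."""
    count: List[int] = [c + a for c, a in zip(state["count"], state["action"])]
    # An element is done once its count reaches its own max_count.
    done = [c >= m for c, m in zip(count, state["max_count"])]
    return {
        "count": count,
        "max_count": list(state["max_count"]),
        "done": done,
        "terminated": list(done),
        "truncated": [False] * len(done),
        "reward": [0.0] * len(done),
    }


# The same function handles any batch length because nothing is stored between calls.
next_state = stateless_count_step(
    {"count": [9, 3], "max_count": [10, 12], "action": [1, 1]}
)
```

Because the counter lives in the data rather than in the environment, stepping a subset of the batch amounts to passing only that subset to the function.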

test/test_env.py

Lines changed: 47 additions & 35 deletions
@@ -4112,17 +4112,21 @@ def test_parallel_partial_steps(
41124112
use_buffers=use_buffers,
41134113
device=device,
41144114
)
4115-
td = penv.reset()
4116-
psteps = torch.zeros(4, dtype=torch.bool)
4117-
psteps[[1, 3]] = True
4118-
td.set("_step", psteps)
4119-
4120-
td.set("action", penv.full_action_spec[penv.action_key].one())
4121-
td = penv.step(td)
4122-
assert (td[0].get("next") == 0).all()
4123-
assert (td[1].get("next") != 0).any()
4124-
assert (td[2].get("next") == 0).all()
4125-
assert (td[3].get("next") != 0).any()
4115+
try:
4116+
td = penv.reset()
4117+
psteps = torch.zeros(4, dtype=torch.bool)
4118+
psteps[[1, 3]] = True
4119+
td.set("_step", psteps)
4120+
4121+
td.set("action", penv.full_action_spec[penv.action_key].one())
4122+
td = penv.step(td)
4123+
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
4124+
assert (td[1].get("next") != 0).any()
4125+
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
4126+
assert (td[3].get("next") != 0).any()
4127+
finally:
4128+
penv.close()
4129+
del penv
41264130

41274131
@pytest.mark.parametrize("use_buffers", [False, True])
41284132
def test_parallel_partial_step_and_maybe_reset(
@@ -4135,17 +4139,21 @@ def test_parallel_partial_step_and_maybe_reset(
41354139
use_buffers=use_buffers,
41364140
device=device,
41374141
)
4138-
td = penv.reset()
4139-
psteps = torch.zeros(4, dtype=torch.bool)
4140-
psteps[[1, 3]] = True
4141-
td.set("_step", psteps)
4142-
4143-
td.set("action", penv.full_action_spec[penv.action_key].one())
4144-
td, tdreset = penv.step_and_maybe_reset(td)
4145-
assert (td[0].get("next") == 0).all()
4146-
assert (td[1].get("next") != 0).any()
4147-
assert (td[2].get("next") == 0).all()
4148-
assert (td[3].get("next") != 0).any()
4142+
try:
4143+
td = penv.reset()
4144+
psteps = torch.zeros(4, dtype=torch.bool)
4145+
psteps[[1, 3]] = True
4146+
td.set("_step", psteps)
4147+
4148+
td.set("action", penv.full_action_spec[penv.action_key].one())
4149+
td, tdreset = penv.step_and_maybe_reset(td)
4150+
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
4151+
assert (td[1].get("next") != 0).any()
4152+
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
4153+
assert (td[3].get("next") != 0).any()
4154+
finally:
4155+
penv.close()
4156+
del penv
41494157

41504158
@pytest.mark.parametrize("use_buffers", [False, True])
41514159
def test_serial_partial_steps(self, use_buffers, device, env_device):
@@ -4156,17 +4164,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
41564164
use_buffers=use_buffers,
41574165
device=device,
41584166
)
4159-
td = penv.reset()
4160-
psteps = torch.zeros(4, dtype=torch.bool)
4161-
psteps[[1, 3]] = True
4162-
td.set("_step", psteps)
4163-
4164-
td.set("action", penv.full_action_spec[penv.action_key].one())
4165-
td = penv.step(td)
4166-
assert (td[0].get("next") == 0).all()
4167-
assert (td[1].get("next") != 0).any()
4168-
assert (td[2].get("next") == 0).all()
4169-
assert (td[3].get("next") != 0).any()
4167+
try:
4168+
td = penv.reset()
4169+
psteps = torch.zeros(4, dtype=torch.bool)
4170+
psteps[[1, 3]] = True
4171+
td.set("_step", psteps)
4172+
4173+
td.set("action", penv.full_action_spec[penv.action_key].one())
4174+
td = penv.step(td)
4175+
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
4176+
assert (td[1].get("next") != 0).any()
4177+
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
4178+
assert (td[3].get("next") != 0).any()
4179+
finally:
4180+
penv.close()
4181+
del penv
41704182

41714183
@pytest.mark.parametrize("use_buffers", [False, True])
41724184
def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
@@ -4184,9 +4196,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
41844196

41854197
td.set("action", penv.full_action_spec[penv.action_key].one())
41864198
td = penv.step(td)
4187-
assert (td[0].get("next") == 0).all()
4199+
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
41884200
assert (td[1].get("next") != 0).any()
4189-
assert (td[2].get("next") == 0).all()
4201+
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
41904202
assert (td[3].get("next") != 0).any()
41914203

41924204
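
The test changes above wrap each body in `try` / `finally` so that `penv.close()` runs even when an assertion fails. A minimal sketch of that pattern, assuming a hypothetical `FakeEnv` in place of `ParallelEnv` or `SerialEnv`:

```python
class FakeEnv:
    """Hypothetical stand-in for an env that owns resources such as worker processes."""

    def __init__(self) -> None:
        self.closed = False

    def step(self) -> None:
        # Simulate a failing assertion inside the test body.
        raise AssertionError("simulated test failure")

    def close(self) -> None:
        self.closed = True


def run_check(env: FakeEnv) -> None:
    try:
        env.step()
    finally:
        # Runs whether or not the body raised, so resources are always released.
        env.close()


env = FakeEnv()
failed = False
try:
    run_check(env)
except AssertionError:
    failed = True
```

Here the failure still propagates to the caller, but the environment is closed first, so a failing test does not leave resources open for the tests that follow.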
