Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Transform for partial steps #2777

Merged
merged 15 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 0 additions & 32 deletions .github/workflows/test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,38 +53,6 @@ jobs:
## setup_env.sh
bash .github/unittest/linux/scripts/run_all.sh

tests-cpu-oldget:
# Tests that TD_GET_DEFAULTS_TO_NONE=0 works fine as this will be the default for TD up to 0.7
strategy:
matrix:
python_version: ["3.12"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
runner: linux.12xlarge
repository: pytorch/rl
docker-image: "nvidia/cuda:12.2.0-devel-ubuntu22.04"
timeout: 90
script: |
if [[ "${{ github.ref }}" =~ release/* ]]; then
export RELEASE=1
export TORCH_VERSION=stable
else
export RELEASE=0
export TORCH_VERSION=nightly
fi
export TD_GET_DEFAULTS_TO_NONE=0

# Set env vars from matrix
export PYTHON_VERSION=${{ matrix.python_version }}
export CU_VERSION="cpu"

echo "PYTHON_VERSION: $PYTHON_VERSION"
echo "CU_VERSION: $CU_VERSION"

## setup_env.sh
bash .github/unittest/linux/scripts/run_all.sh

tests-gpu:
strategy:
matrix:
Expand Down
76 changes: 76 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,81 @@ provides more information on how to design a custom environment from scratch.
GymLikeEnv
EnvMetaData

Partial steps and partial resets
--------------------------------

TorchRL allows environments to reset some but not all the environments, or run a step in one but not all environments.
If there is only one environment in the batch, then a partial reset / step is also allowed with the behavior detailed
below.

Batching environments and locking the batch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. _ref_batch_locked:

Before detailing what partial resets and partial steps do, we must distinguish cases where an environment has
a batch size of its own (mostly stateful environments) or when the environment is just a mere module that, given an
input of arbitrary size, batches the operations over all elements (mostly stateless environments).

This is controlled via the :attr:`~torchrl.envs.batch_locked` attribute: a batch-locked environment requires all input
tensordicts to have the same batch-size as the env's. Typical examples of these environments are
:class:`~torchrl.envs.GymEnv` and related. Batch-unlocked envs are by contrast allowed to work with any input size.
Notable examples are :class:`~torchrl.envs.BraxEnv` or :class:`~torchrl.envs.JumanjiEnv`.

Executing partial steps in a batch-unlocked environment is straightforward: one just needs to mask the part of the
tensordict that does not need to be executed, pass the other part to `step` and merge the results with the previous
input.

Batched environments (:class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv`) can also deal with
partial steps easily, they just pass the actions to the sub-environments that are required to be executed.

In all other cases, TorchRL assumes that the environment handles the partial steps correctly.

.. warning:: This means that custom environments may silently run the non-required steps as there is no way for torchrl
to control what happens within the `_step` method!

Partial Steps
~~~~~~~~~~~~~

.. _ref_partial_steps:

Partial steps are controlled via the temporary key `"_step"` which points to a boolean mask of the
size of the tensordict that holds it. The classes armed to deal with this are:

- Batched environments: :class:`~torchrl.envs.ParallelEnv` and :class:`~torchrl.envs.SerialEnv` will dispatch the
action to and only to the environments where `"_step"` is `True`;
- Batch-unlocked environments;
- Unbatched environments (i.e., environments without batch size). In these environments, the :meth:`~torchrl.envs.EnvBase.step`
method will first look for a `"_step"` entry and, if present, act accordingly.
If a :class:`~torchrl.envs.Transform` instance passes a `"_step"` entry to the tensordict, it is also captured by
:class:`~torchrl.envs.TransformedEnv`'s own `_step` method which will skip the `base_env.step` as well as any further
transformation.

When dealing with partial steps, the strategy is always to use the step output and mask missing values with the previous
content of the input tensordict, if present, or a `0`-valued tensor if the tensor cannot be found. This means that
if the input tensordict does not contain all the previous observations, then the output tensordict will be 0-valued for
all the non-stepped elements. Within batched environments, data collectors and rollouts utils, this is an issue that
is not observed because these classes handle the passing of data properly.

Partial steps are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_all_done` is `True`,
as the environments with a `True` done state will need to be skipped during calls to `_step`.

The :class:`~torchrl.envs.ConditionalSkip` transform allows you to programmatically ask for (partial) step skips.

Partial Resets
~~~~~~~~~~~~~~

.. _ref_partial_resets:

Partial resets work pretty much like partial steps, but with the `"_reset"` entry.

The same restrictions of partial steps apply to partial resets.

Likewise, partial resets are an essential feature of :meth:`~torchrl.envs.EnvBase.rollout` when `break_when_any_done` is `True`,
as the environments with a `True` done state will need to be reset, but not others.

See te following paragraph for a deep dive in partial resets within batched and vectorized environments.

Vectorized envs
---------------

Expand Down Expand Up @@ -886,6 +961,7 @@ to be able to create this other composition:
CenterCrop
ClipTransform
Compose
ConditionalSkip
Crop
DTypeCastTransform
DeviceCastTransform
Expand Down
11 changes: 10 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,8 @@ def _step(


class ContinuousActionVecMockEnv(_MockEnv):
adapt_dtype: bool = True

@classmethod
def __new__(
cls,
Expand Down Expand Up @@ -635,7 +637,14 @@ def _step(
while done.shape != tensordict.shape:
done = done.any(-1)
done = reward = done.unsqueeze(-1)
tensordict.set("reward", reward.to(torch.get_default_dtype()))
tensordict.set(
"reward",
reward.to(
self.reward_spec.dtype
if self.adapt_dtype
else torch.get_default_dtype()
).expand(self.reward_spec.shape),
)
tensordict.set("done", done)
tensordict.set("terminated", done)
return tensordict
Expand Down
83 changes: 48 additions & 35 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@
class TestEnvBase:
def test_run_type_checks(self):
env = ContinuousActionVecMockEnv()
env.adapt_dtype = False
env._run_type_checks = False
check_env_specs(env)
env._run_type_checks = True
Expand Down Expand Up @@ -4112,17 +4113,21 @@ def test_parallel_partial_steps(
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_step", psteps)

td.set("action", penv.full_action_spec[penv.action_key].one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()
try:
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_step", psteps)

td.set("action", penv.full_action_spec[penv.action_key].one())
td = penv.step(td)
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
assert (td[1].get("next") != 0).any()
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
assert (td[3].get("next") != 0).any()
finally:
penv.close()
del penv

@pytest.mark.parametrize("use_buffers", [False, True])
def test_parallel_partial_step_and_maybe_reset(
Expand All @@ -4135,17 +4140,21 @@ def test_parallel_partial_step_and_maybe_reset(
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_step", psteps)

td.set("action", penv.full_action_spec[penv.action_key].one())
td, tdreset = penv.step_and_maybe_reset(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()
try:
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool, device=td.get("done").device)
psteps[[1, 3]] = True
td.set("_step", psteps)

td.set("action", penv.full_action_spec[penv.action_key].one())
td, tdreset = penv.step_and_maybe_reset(td)
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
assert (td[1].get("next") != 0).any()
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
assert (td[3].get("next") != 0).any()
finally:
penv.close()
del penv

@pytest.mark.parametrize("use_buffers", [False, True])
def test_serial_partial_steps(self, use_buffers, device, env_device):
Expand All @@ -4156,17 +4165,21 @@ def test_serial_partial_steps(self, use_buffers, device, env_device):
use_buffers=use_buffers,
device=device,
)
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_step", psteps)

td.set("action", penv.full_action_spec[penv.action_key].one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert (td[3].get("next") != 0).any()
try:
td = penv.reset()
psteps = torch.zeros(4, dtype=torch.bool)
psteps[[1, 3]] = True
td.set("_step", psteps)

td.set("action", penv.full_action_spec[penv.action_key].one())
td = penv.step(td)
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
assert (td[1].get("next") != 0).any()
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
assert (td[3].get("next") != 0).any()
finally:
penv.close()
del penv

@pytest.mark.parametrize("use_buffers", [False, True])
def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_device):
Expand All @@ -4184,9 +4197,9 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi

td.set("action", penv.full_action_spec[penv.action_key].one())
td = penv.step(td)
assert (td[0].get("next") == 0).all()
assert_allclose_td(td[0].get("next"), td[0], intersection=True)
assert (td[1].get("next") != 0).any()
assert (td[2].get("next") == 0).all()
assert_allclose_td(td[2].get("next"), td[2], intersection=True)
assert (td[3].get("next") != 0).any()


Expand Down
Loading
Loading