
Commit eec400d

[Feature] Transform for partial steps
ghstack-source-id: cd6c967ac6d793e078cac90c340942f23ffb16f4
Pull Request resolved: #2777
1 parent b27ee6d commit eec400d
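
The commit adds a ConditionalSkip transform that lets a user-supplied condition on the root tensordict bypass the base environment's step for some or all batch entries. A minimal usage sketch, pieced together from the tests added below; the GymEnv("CartPole-v1") backend and the random-action policy are illustrative assumptions, not part of the commit:

from torchrl.envs import GymEnv, StepCounter, TransformedEnv
from torchrl.envs.transforms import ConditionalSkip

# Any TorchRL env works here; CartPole via GymEnv assumes gymnasium is installed.
base_env = GymEnv("CartPole-v1")
env = TransformedEnv(base_env, StepCounter())
# cond returns True for the entries whose base-env step should be skipped;
# here every odd step is skipped, mirroring test_single_trans_env_check below.
env = env.append_transform(ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1))
rollout = env.rollout(10, policy=lambda td: td.set("action", env.action_spec.rand()))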

6 files changed (+336 -16 lines)

test/test_transforms.py (+154)
@@ -58,6 +58,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalSkip,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
@@ -13451,6 +13452,159 @@ def test_composite_reward_spec(self) -> None:
         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
 
 
+class TestConditionalSkip(TransformBase):
+    @pytest.mark.parametrize("bwad", [False, True])
+    def test_single_trans_env_check(self, bwad):
+        env = CountingEnv()
+        base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
+        env = TransformedEnv(base_env, StepCounter(), auto_unwrap=False)
+        env = env.append_transform(
+            ConditionalSkip(cond=lambda td: td["step_count"] % 2 == 1)
+        )
+        env.set_seed(0)
+        env.check_env_specs()
+        policy = lambda td: td.set("action", torch.ones((1,)))
+        r = env.rollout(10, policy, break_when_any_done=bwad)
+        assert (r["step_count"] == torch.arange(10).view(10, 1)).all()
+        assert (r["other_count"] == torch.arange(1, 11).view(10, 1) // 2).all()
+
+    @pytest.mark.parametrize("bwad", [False, True])
+    def test_serial_trans_env_check(self, bwad):
+        def make_env(i):
+            env = CountingEnv()
+            base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
+            return TransformedEnv(
+                base_env,
+                Compose(
+                    StepCounter(),
+                    ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)),
+                ),
+                auto_unwrap=False,
+            )
+
+        env = SerialEnv(2, [partial(make_env, i=0), partial(make_env, i=1)])
+        env.check_env_specs()
+        policy = lambda td: td.set("action", torch.ones((2, 1)))
+        r = env.rollout(10, policy, break_when_any_done=bwad)
+        assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all()
+        assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all()
+        assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all()
+
+    @pytest.mark.parametrize("bwad", [False, True])
+    def test_parallel_trans_env_check(self, bwad):
+        def make_env(i):
+            env = CountingEnv()
+            base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
+            return TransformedEnv(
+                base_env,
+                Compose(
+                    StepCounter(),
+                    ConditionalSkip(cond=lambda td, i=i: (td["step_count"] % 2 == i)),
+                ),
+                auto_unwrap=False,
+            )
+
+        env = ParallelEnv(
+            2, [partial(make_env, i=0), partial(make_env, i=1)], mp_start_method=mp_ctx
+        )
+        try:
+            env.check_env_specs()
+            policy = lambda td: td.set("action", torch.ones((2, 1)))
+            r = env.rollout(10, policy, break_when_any_done=bwad)
+            assert (
+                r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)
+            ).all()
+            assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all()
+            assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all()
+        finally:
+            env.close()
+            del env
+
+    @pytest.mark.parametrize("bwad", [False, True])
+    def test_trans_serial_env_check(self, bwad):
+        def make_env():
+            env = CountingEnv(max_steps=100)
+            base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
+            return base_env
+
+        base_env = SerialEnv(2, [make_env, make_env])
+
+        def cond(td):
+            sc = td["step_count"] + torch.tensor([[0], [1]])
+            return sc.squeeze() % 2 == 0
+
+        env = TransformedEnv(base_env, Compose(StepCounter(), ConditionalSkip(cond)))
+        env.check_env_specs()
+        policy = lambda td: td.set("action", torch.ones((2, 1)))
+        r = env.rollout(10, policy, break_when_any_done=bwad)
+        assert (r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)).all()
+        assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all()
+        assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all()
+
+    @pytest.mark.parametrize("bwad", [True, False])
+    @pytest.mark.parametrize("buffers", [True, False])
+    def test_trans_parallel_env_check(self, bwad, buffers):
+        def make_env():
+            env = CountingEnv(max_steps=100)
+            base_env = TransformedEnv(env, StepCounter(step_count_key="other_count"))
+            return base_env
+
+        base_env = ParallelEnv(
+            2, [make_env, make_env], mp_start_method=mp_ctx, use_buffers=buffers
+        )
+        try:
+
+            def cond(td):
+                sc = td["step_count"] + torch.tensor([[0], [1]])
+                return sc.squeeze() % 2 == 0
+
+            env = TransformedEnv(
+                base_env, Compose(StepCounter(), ConditionalSkip(cond))
+            )
+            env.check_env_specs()
+            policy = lambda td: td.set("action", torch.ones((2, 1)))
+            r = env.rollout(10, policy, break_when_any_done=bwad)
+            assert (
+                r["step_count"] == torch.arange(10).view(10, 1).expand(2, 10, 1)
+            ).all()
+            assert (r["other_count"][0] == torch.arange(0, 10).view(10, 1) // 2).all()
+            assert (r["other_count"][1] == torch.arange(1, 11).view(10, 1) // 2).all()
+        finally:
+            base_env.close()
+            del base_env
+
+    def test_transform_no_env(self):
+        t = ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0)
+        assert not t._inv_call(TensorDict())["_step"]
+        assert t._inv_call(TensorDict())["_step"].shape == ()
+        assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3)
+
+    def test_transform_compose(self):
+        t = Compose(
+            ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0)
+        )
+        assert not t._inv_call(TensorDict())["_step"]
+        assert t._inv_call(TensorDict())["_step"].shape == ()
+        assert t._inv_call(TensorDict(batch_size=(2, 3)))["_step"].shape == (2, 3)
+
+    def test_transform_env(self):
+        # tested above
+        return
+
+    def test_transform_model(self):
+        t = Compose(
+            ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0)
+        )
+        with pytest.raises(NotImplementedError):
+            t(TensorDict())["_step"]
+
+    def test_transform_rb(self):
+        return
+
+    def test_transform_inverse(self):
+        return
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
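
The _inv_call assertions above pin down the convention for the "_step" entry: a boolean mask with the tensordict's batch shape. A small sketch of that convention (the expected values assume "_step" is simply the negation of cond, which is what the empty-tensordict assertions imply):

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import ConditionalSkip

t = ConditionalSkip(lambda td: torch.arange(td.numel()).view(td.shape) % 2 == 0)
masked = t._inv_call(TensorDict(batch_size=(2, 3)))
# "_step" keeps the batch shape; False where cond was True (skip),
# True where the base env should actually be stepped.
assert masked["_step"].shape == (2, 3)
assert (masked["_step"].flatten() == torch.tensor([False, True, False, True, False, True])).all()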

torchrl/envs/__init__.py (+1)
@@ -56,6 +56,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalSkip,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,

torchrl/envs/batched_envs.py (+25 -4)
@@ -1089,7 +1089,7 @@ def _step(
         self,
         tensordict: TensorDict,
     ) -> TensorDict:
-        partial_steps = tensordict.get("_step", None)
+        partial_steps = tensordict.get("_step")
         tensordict_save = tensordict
         if partial_steps is not None and partial_steps.all():
             partial_steps = None
@@ -1164,6 +1164,10 @@ def select_and_clone(name, tensor):
 
         if partial_steps is not None:
             result = out.new_zeros(tensordict_save.shape)
+            # Copy the observation data from the previous step as placeholder
+            result.update(
+                tensordict_save.select(*result.keys(True, True), strict=False).clone()
+            )
             result[partial_steps] = out
             return result
 
@@ -1529,6 +1533,9 @@ def _step_and_maybe_reset_no_buffers(
         )
         if partial_steps is not None:
             result = out.new_zeros(tensordict_save.shape)
+            result.update(
+                tensordict_save.select(*result.keys(True, True), strict=False).clone()
+            )
             result[partial_steps] = out
             return result
         return out
@@ -1543,7 +1550,7 @@ def step_and_maybe_reset(
             # return self._step_and_maybe_reset_no_buffers(tensordict)
             return super().step_and_maybe_reset(tensordict)
 
-        partial_steps = tensordict.get("_step", None)
+        partial_steps = tensordict.get("_step")
         tensordict_save = tensordict
         if partial_steps is not None and partial_steps.all():
             partial_steps = None
@@ -1661,6 +1668,14 @@
         if partial_steps is not None:
             result = tensordict.new_zeros(tensordict_save.shape)
             result_ = tensordict_.new_zeros(tensordict_save.shape)
+
+            result.update(
+                tensordict_save.select(*result.keys(True, True), strict=False).clone()
+            )
+            result_.update(
+                tensordict_save.select(*result_.keys(True, True), strict=False).clone()
+            )
+
             result[partial_steps] = tensordict
             result_[partial_steps] = tensordict_
             return result, result_
@@ -1700,7 +1715,7 @@ def _wait_for_workers(self, workers_range):
     def _step_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
-        partial_steps = tensordict.get("_step", None)
+        partial_steps = tensordict.get("_step")
         tensordict_save = tensordict
         if partial_steps is not None and partial_steps.all():
             partial_steps = None
@@ -1727,6 +1742,9 @@ def _step_no_buffers(
         out = out.to(self.device, non_blocking=self.non_blocking)
         if partial_steps is not None:
            result = out.new_zeros(tensordict_save.shape)
+            result.update(
+                tensordict_save.select(*result.keys(True, True), strict=False).clone()
+            )
             result[partial_steps] = out
             return result
         return out
@@ -1744,7 +1762,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
         # and this transform overrides an observation key (eg, CatFrames)
         # the shape, dtype or device may not necessarily match and writing
         # the value in-place will fail.
-        partial_steps = tensordict.get("_step", None)
+        partial_steps = tensordict.get("_step")
         tensordict_save = tensordict
         if partial_steps is not None and partial_steps.all():
             partial_steps = None
@@ -1875,6 +1893,9 @@ def select_and_clone(name, tensor):
         self._sync_w2m()
         if partial_steps is not None:
             result = out.new_zeros(tensordict_save.shape)
+            result.update(
+                tensordict_save.select(*result.keys(True, True), strict=False).clone()
+            )
             result[partial_steps] = out
             return result
         return out
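
The hunks above repeat one pattern: allocate a placeholder of the full batch shape, pre-fill it with the values saved before the step, then scatter the stepped sub-batch back at the partial_steps positions. A standalone sketch of that write-back on plain TensorDicts (toy keys and values, not the library code):

import torch
from tensordict import TensorDict

# Saved (pre-step) data for a batch of 4 workers; only workers 0 and 2 step.
tensordict_save = TensorDict({"obs": torch.arange(4.0)}, batch_size=[4])
partial_steps = torch.tensor([True, False, True, False])
# `out` stands for the result of stepping only the selected sub-batch.
out = TensorDict({"obs": torch.full((2,), 100.0)}, batch_size=[2])

result = out.new_zeros(tensordict_save.shape)
# Placeholder entries come from the saved tensordict...
result.update(tensordict_save.select(*result.keys(True, True), strict=False).clone())
# ...and the stepped entries overwrite them at the selected positions.
result[partial_steps] = out
assert (result["obs"] == torch.tensor([100.0, 1.0, 100.0, 3.0])).all()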

torchrl/envs/common.py (+28 -9)
@@ -1933,6 +1933,17 @@ def state_spec_unbatched(self, spec: Composite):
         spec = spec.expand(self.batch_size + spec.shape)
         self.state_spec = spec
 
+    def _skip_tensordict(self, tensordict):
+        # Creates a "skip" tensordict, ie a placeholder for when a step is skipped
+        next_tensordict = self.full_done_spec.zero()
+        next_tensordict.update(self.full_observation_spec.zero())
+        next_tensordict.update(self.full_reward_spec.zero())
+        # Copy the data from tensordict in `next`
+        next_tensordict.update(
+            tensordict.select(*next_tensordict.keys(True, True), strict=False)
+        )
+        return next_tensordict
+
     def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Makes a step in the environment.
 
@@ -1953,25 +1964,33 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         """
         # sanity check
         self._assert_tensordict_shape(tensordict)
-        partial_steps = None
+        partial_steps = tensordict.pop("_step", None)
 
-        if not self.batch_locked:
-            # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
-            partial_steps = tensordict.get("_step", None)
-            if partial_steps is not None:
+        next_tensordict = None
+
+        if partial_steps is not None:
+            if not self.batch_locked:
+                # Batched envs have their own way of dealing with this - batched envs that are not batched-locked may fail here
                 if partial_steps.all():
                     partial_steps = None
                 else:
                     tensordict_batch_size = tensordict.batch_size
                     partial_steps = partial_steps.view(tensordict_batch_size)
                     tensordict = tensordict[partial_steps]
-        else:
+            else:
+                if not partial_steps.any():
+                    next_tensordict = self._skip_tensordict(tensordict)
+                else:
+                    # trust that the _step can handle this!
+                    tensordict.set("_step", partial_steps)
+
             tensordict_batch_size = self.batch_size
 
         next_preset = tensordict.get("next", None)
 
-        next_tensordict = self._step(tensordict)
-        next_tensordict = self._step_proc_data(next_tensordict)
+        if next_tensordict is None:
+            next_tensordict = self._step(tensordict)
+            next_tensordict = self._step_proc_data(next_tensordict)
         if next_preset is not None:
             # tensordict could already have a "next" key
             # this could be done more efficiently by not excluding but just passing
@@ -1980,7 +1999,7 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
                 next_preset.exclude(*next_tensordict.keys(True, True))
             )
         tensordict.set("next", next_tensordict)
-        if partial_steps is not None:
+        if partial_steps is not None and tensordict_batch_size != self.batch_size:
            result = tensordict.new_zeros(tensordict_batch_size)
            result[partial_steps] = tensordict
            return result
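
_skip_tensordict builds the placeholder "next" tensordict returned when every entry of a batch-locked env is skipped: zeros for the done and reward specs, with any matching entries of the current tensordict (typically the observations) copied over. A toy reconstruction of that behaviour on hand-made specs (the spec names and shapes are illustrative; the real method uses the env's own full_*_spec attributes):

import torch
from tensordict import TensorDict
from torchrl.data import Categorical, Composite, Unbounded

full_observation_spec = Composite(observation=Unbounded(shape=(3,)))
full_reward_spec = Composite(reward=Unbounded(shape=(1,)))
full_done_spec = Composite(done=Categorical(2, shape=(1,), dtype=torch.bool))

tensordict = TensorDict({"observation": torch.ones(3), "action": torch.zeros(1)}, [])

next_tensordict = full_done_spec.zero()
next_tensordict.update(full_observation_spec.zero())
next_tensordict.update(full_reward_spec.zero())
# The current observation is carried over; reward and done stay at their zero defaults.
next_tensordict.update(
    tensordict.select(*next_tensordict.keys(True, True), strict=False)
)
assert (next_tensordict["observation"] == 1).all()
assert next_tensordict["reward"] == 0 and not next_tensordict["done"]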

torchrl/envs/transforms/__init__.py (+1)
@@ -20,6 +20,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalSkip,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
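
Both __init__ hunks only re-export the new class, so the two public import paths resolve to the same object; a quick sanity check (not part of the commit):

from torchrl.envs import ConditionalSkip
from torchrl.envs.transforms import ConditionalSkip as _ConditionalSkip

assert ConditionalSkip is _ConditionalSkip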
