Skip to content

Commit 74f6075

Browse files
committed
[Feature] MultiAction transform
ghstack-source-id: fb3940e Pull Request resolved: #2779
1 parent 3f92dd9 commit 74f6075

File tree

9 files changed

+578
-46
lines changed

9 files changed

+578
-46
lines changed

docs/source/reference/envs.rst

+1
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@ to be able to create this other composition:
977977
InitTracker
978978
KLRewardTransform
979979
LineariseRewards
980+
MultiAction
980981
NoopResetEnv
981982
ObservationNorm
982983
ObservationTransform

test/mocking_classes.py

+94-15
Original file line numberDiff line numberDiff line change
@@ -358,13 +358,11 @@ def _step(self, tensordict):
358358
leading_batch_size = tensordict.shape if tensordict is not None else []
359359
self.counter += 1
360360
# We use tensordict.batch_size instead of self.batch_size since this method will also be used by MockBatchedUnLockedEnv
361-
n = (
362-
torch.full(
363-
[*leading_batch_size, *self.observation_spec["observation"].shape],
364-
self.counter,
365-
)
366-
.to(self.device)
367-
.to(torch.get_default_dtype())
361+
n = torch.full(
362+
[*leading_batch_size, *self.observation_spec["observation"].shape],
363+
self.counter,
364+
device=self.device,
365+
dtype=torch.get_default_dtype(),
368366
)
369367
done = self.counter >= self.max_val
370368
done = torch.full(
@@ -391,13 +389,11 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
391389
else:
392390
leading_batch_size = tensordict.shape if tensordict is not None else []
393391

394-
n = (
395-
torch.full(
396-
[*leading_batch_size, *self.observation_spec["observation"].shape],
397-
self.counter,
398-
)
399-
.to(self.device)
400-
.to(torch.get_default_dtype())
392+
n = torch.full(
393+
[*leading_batch_size, *self.observation_spec["observation"].shape],
394+
self.counter,
395+
device=self.device,
396+
dtype=torch.get_default_dtype(),
401397
)
402398
done = self.counter >= self.max_val
403399
done = torch.full(
@@ -417,7 +413,7 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
417413

418414

419415
class MockBatchedUnLockedEnv(MockBatchedLockedEnv):
420-
"""Mocks an env whose batch_size does not define the size of the output tensordict.
416+
"""Mocks an env whose batch_size does not define the size of the output tensordict.
421417
422418
The size of the output tensordict is defined by the input tensordict itself.
423419
@@ -433,6 +429,89 @@ def __new__(cls, *args, **kwargs):
433429
return super().__new__(cls, *args, _batch_locked=False, **kwargs)
434430

435431

432+
class StateLessCountingEnv(EnvBase):
    """Mock counting environment that carries its state in the tensordict.

    The current ``count`` and the ``max_count`` threshold are read from the
    input tensordict and written to the output, rather than being stored on
    the environment instance. ``_batch_locked`` is set to ``False`` after
    construction, and output shapes follow the input tensordict's batch size.
    """

    def __init__(self):
        self.observation_spec = Composite(
            count=Unbounded((1,), dtype=torch.int32),
            max_count=Unbounded((1,), dtype=torch.int32),
        )
        self.full_action_spec = Composite(
            action=Unbounded((1,), dtype=torch.int32),
        )
        # The key must be spelled "terminated" so that the declared spec
        # matches the entry written by ``_step``.
        self.full_done_spec = Composite(
            done=Unbounded((1,), dtype=torch.bool),
            terminated=Unbounded((1,), dtype=torch.bool),
            truncated=Unbounded((1,), dtype=torch.bool),
        )
        self.reward_spec = Composite(reward=Unbounded((1,), dtype=torch.float))
        super().__init__()
        self._batch_locked = False

    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
        """Build a reset tensordict, reusing ``count``/``max_count`` if given.

        Args:
            tensordict: optional input. When provided, its batch size and
                device are used for the output, and the values returned by
                ``tensordict.get("count")`` / ``tensordict.get("max_count")``
                are kept whenever they are not ``None``.

        Returns:
            A tensordict holding ``count``, ``max_count`` and zero-valued
            done and reward entries.
        """
        max_count = None
        count = None
        if tensordict is not None:
            max_count = tensordict.get("max_count")
            count = tensordict.get("count")
            tensordict = TensorDict(
                batch_size=tensordict.batch_size, device=tensordict.device
            )
            shape = tensordict.batch_size
        else:
            shape = ()
            tensordict = TensorDict(device=self.device)

        # Defaults are only created when the caller did not supply a value,
        # so the random generator is untouched when ``max_count`` is given.
        if count is None:
            count = torch.zeros((*shape, 1), dtype=torch.int32)
        if max_count is None:
            max_count = torch.randint(10, 20, (*shape, 1), dtype=torch.int32)

        tensordict.update(
            TensorDict(
                count=count,
                max_count=max_count,
                **self.done_spec.zero(shape),
                **self.full_reward_spec.zero(shape),
            )
        )
        return tensordict

    def _step(
        self,
        tensordict: TensorDictBase,
    ) -> TensorDictBase:
        """Add ``action`` to ``count`` and flag completion.

        ``done`` and ``terminated`` are both ``count >= max_count`` (computed
        on the updated count); ``truncated`` is all-zero and the reward is the
        zero value of the reward spec.
        """
        action = tensordict["action"]
        count = tensordict["count"] + action
        terminated = done = count >= tensordict["max_count"]
        truncated = torch.zeros_like(done)
        return TensorDict(
            count=count,
            max_count=tensordict["max_count"],
            done=done,
            terminated=terminated,
            truncated=truncated,
            reward=self.reward_spec.zero(tensordict.shape),
            batch_size=tensordict.batch_size,
            device=tensordict.device,
        )

    def _set_seed(self, seed: Optional[int]):
        """No-op: this mock keeps no seeded state of its own."""
        ...
513+
514+
436515
class DiscreteActionVecMockEnv(_MockEnv):
437516
@classmethod
438517
def __new__(

test/test_transforms.py

+195
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777
Hash,
7878
InitTracker,
7979
LineariseRewards,
80+
MultiAction,
8081
MultiStepTransform,
8182
NoopResetEnv,
8283
ObservationNorm,
@@ -156,6 +157,7 @@
156157
MultiKeyCountingEnv,
157158
MultiKeyCountingEnvPolicy,
158159
NestedCountingEnv,
160+
StateLessCountingEnv,
159161
)
160162
else:
161163
from _utils_internal import ( # noqa
@@ -184,6 +186,7 @@
184186
MultiKeyCountingEnv,
185187
MultiKeyCountingEnvPolicy,
186188
NestedCountingEnv,
189+
StateLessCountingEnv,
187190
)
188191

189192
IS_WIN = platform == "win32"
@@ -13664,6 +13667,198 @@ def test_transform_inverse(self):
1366413667
return
1366513668

1366613669

13670+
class TestMultiAction(TransformBase):
    """Tests for the ``MultiAction`` transform on single and batched envs."""

    @staticmethod
    def _counting_compose():
        """Return a fresh ``before_count`` / ``MultiAction`` / ``after_count`` stack."""
        return Compose(
            StepCounter(step_count_key="before_count"),
            MultiAction(),
            StepCounter(step_count_key="after_count"),
        )

    @pytest.mark.parametrize("bwad", [False, True])
    def test_single_trans_env_check(self, bwad):
        """Roll out a single transformed ``CountingEnv`` with 3 actions per step."""
        # NOTE(review): ``bwad`` is not forwarded to ``rollout`` below, so both
        # parametrizations exercise the same code path.
        env = TransformedEnv(CountingEnv(max_steps=10), self._counting_compose())
        env.check_env_specs()

        def three_actions_policy(td):
            # 3 actions per step
            td["action"] = torch.ones(3, 1)
            return td

        rollout = env.rollout(10, three_actions_policy)
        assert rollout["action"].shape == (4, 3, 1)

        next_done = rollout["next", "done"]
        assert next_done.any()
        assert next_done[-1].all()

        assert (rollout["observation"][0] == 0).all()
        next_obs = rollout["next", "observation"]
        assert (next_obs[0] == 3).all()
        assert (next_obs[-1] == 11).all()

        # Check that before_count is incremented but not after_count
        assert rollout["before_count"].max() == 9
        assert rollout["after_count"].max() == 3

    def _batched_trans_env_check(self, cls, bwad, within):
        """Shared body for the batched-env tests.

        Args:
            cls: batched env constructor, called as ``cls(2, [factory, factory])``.
            bwad: forwarded to ``rollout`` as ``break_when_any_done``.
            within: if true, each sub-env is wrapped in the transform stack;
                otherwise the stack wraps the batched env as a whole.
        """
        if within:

            def make_env(max_steps):
                # Built inline so the factory stays self-contained: it is
                # handed to the batched env constructor through ``partial``.
                return TransformedEnv(
                    CountingEnv(max_steps=max_steps),
                    Compose(
                        StepCounter(step_count_key="before_count"),
                        MultiAction(),
                        StepCounter(step_count_key="after_count"),
                    ),
                )

            factories = [
                partial(make_env, max_steps=10),
                partial(make_env, max_steps=20),
            ]
            env = cls(2, factories)
        else:
            factories = [
                partial(CountingEnv, max_steps=10),
                partial(CountingEnv, max_steps=20),
            ]
            batched = cls(2, factories)
            env = TransformedEnv(batched, self._counting_compose())

        try:
            env.check_env_specs()

            def three_actions_policy(td):
                # 3 actions per step, for each of the 2 sub-envs
                td["action"] = torch.ones(2, 3, 1)
                return td

            rollout = env.rollout(10, three_actions_policy, break_when_any_done=bwad)

            # First sub-env (max_steps=10)
            first = rollout[0]
            expected_len = 4 if bwad else 10
            assert rollout["action"].shape == (2, expected_len, 3, 1)

            first_done = first["next", "done"]
            assert first_done.any()
            if bwad:
                assert first_done[-1].all()
            else:
                assert first_done.sum() == 2

            assert (first["observation"][0] == 0).all()
            first_next_obs = first["next", "observation"]
            assert (first_next_obs[0] == 3).all()
            if bwad:
                assert (first_next_obs[-1] == 11).all()
            else:
                assert (first_next_obs[-1] == 6).all(), first_next_obs

            # Check that before_count is incremented but not after_count
            assert first["before_count"].max() == 9
            assert first["after_count"].max() == 3

            # Second sub-env (max_steps=20)
            second = rollout[1]
            second_done = second["next", "done"]
            if bwad:
                assert not second_done.any()
            else:
                assert second_done.any()
                assert second_done.sum() == 1

            assert (second["observation"][0] == 0).all()
            second_next_obs = second["next", "observation"]
            assert (second_next_obs[0] == 3).all()
            # With bwad, the first env cannot go above 11 but the second can,
            # so a 12 shows up because one more step was done.
            expected_last = 12 if bwad else 9
            assert (second_next_obs[-1] == expected_last).all()

            # Check that before_count is incremented but not after_count
            expected_before, expected_after = (9, 3) if bwad else (18, 6)
            assert second["before_count"].max() == expected_before
            assert second["after_count"].max() == expected_after
        finally:
            env.close()

    @pytest.mark.parametrize("bwad", [False, True])
    def test_serial_trans_env_check(self, bwad):
        """``SerialEnv`` built from individually transformed sub-envs."""
        self._batched_trans_env_check(cls=SerialEnv, bwad=bwad, within=True)

    @pytest.mark.parametrize("bwad", [False, True])
    def test_parallel_trans_env_check(self, bwad):
        """``ParallelEnv`` built from individually transformed sub-envs."""
        parallel_cls = partial(ParallelEnv, mp_start_method=mp_ctx)
        self._batched_trans_env_check(cls=parallel_cls, bwad=bwad, within=True)

    @pytest.mark.parametrize("bwad", [False, True])
    def test_trans_serial_env_check(self, bwad):
        """Transform stack applied on top of a ``SerialEnv``."""
        self._batched_trans_env_check(cls=SerialEnv, bwad=bwad, within=False)

    @pytest.mark.parametrize("bwad", [True, False])
    @pytest.mark.parametrize("buffers", [True, False])
    def test_trans_parallel_env_check(self, bwad, buffers):
        """Transform stack applied on top of a ``ParallelEnv``."""
        parallel_cls = partial(
            ParallelEnv, use_buffers=buffers, mp_start_method=mp_ctx
        )
        self._batched_trans_env_check(cls=parallel_cls, bwad=bwad, within=False)

    def test_transform_no_env(self):
        """Placeholder: no check is implemented here."""
        ...

    def test_transform_compose(self):
        """Placeholder: no check is implemented here."""
        ...

    @pytest.mark.parametrize("bwad", [True, False])
    def test_transform_env(self, bwad):
        """Exercise ``MultiAction`` on a stateless (batch-unlocked) env."""
        torch.manual_seed(0)
        env = StateLessCountingEnv()

        def single_action_policy(td):
            td["action"] = torch.ones(td.shape + (1,))
            return td

        start = env.reset().expand(4)
        plain = env.rollout(
            10,
            policy=single_action_policy,
            tensordict=start,
            auto_reset=False,
            break_when_any_done=False,
        )
        expected_count = torch.arange(10).expand(4, 10).unsqueeze(-1)
        assert (plain["count"] == expected_count).all()

        td_reset = env.reset().expand(4).clone()
        td_reset["max_count"] = torch.arange(4, 8).view(4, 1)
        env = TransformedEnv(env, MultiAction())

        def triple_action_policy(td):
            td["action"] = torch.ones(td.shape + (3, 1))
            return td

        # Only checks that the rollout runs; its result is not asserted on.
        env.rollout(
            20,
            policy=triple_action_policy,
            auto_reset=False,
            tensordict=td_reset,
            break_when_any_done=bwad,
        )

    def test_transform_model(self):
        """Placeholder: no check is implemented here."""
        ...

    def test_transform_rb(self):
        """Placeholder: returns immediately."""
        return

    def test_transform_inverse(self):
        """Placeholder: returns immediately."""
        return
13860+
13861+
1366713862
if __name__ == "__main__":
1366813863
args, unknown = argparse.ArgumentParser().parse_known_args()
1366913864
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)