Skip to content

Commit d537dcb

Browse files
committed
[Feature] EnvBase.auto_specs_
ghstack-source-id: 3296792 Pull Request resolved: #2601
1 parent 90572ac commit d537dcb

File tree

7 files changed

+267
-71
lines changed

7 files changed

+267
-71
lines changed

test/mocking_classes.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1038,11 +1038,13 @@ def _step(
10381038
tensordict: TensorDictBase,
10391039
) -> TensorDictBase:
10401040
action = tensordict.get(self.action_key)
1041+
try:
1042+
device = self.full_action_spec[self.action_key].device
1043+
except KeyError:
1044+
device = self.device
10411045
self.count += action.to(
10421046
dtype=torch.int,
1043-
device=self.full_action_spec[self.action_key].device
1044-
if self.device is None
1045-
else self.device,
1047+
device=device if self.device is None else self.device,
10461048
)
10471049
tensordict = TensorDict(
10481050
source={
@@ -1275,8 +1277,10 @@ def __init__(
12751277
max_steps = torch.tensor(5)
12761278
if start_val is None:
12771279
start_val = torch.zeros((), dtype=torch.int32)
1278-
if not max_steps.shape == self.batch_size:
1279-
raise RuntimeError("batch_size and max_steps shape must match.")
1280+
if max_steps.shape != self.batch_size:
1281+
raise RuntimeError(
1282+
f"batch_size and max_steps shape must match. Got self.batch_size={self.batch_size} and max_steps.shape={max_steps.shape}."
1283+
)
12801284

12811285
self.max_steps = max_steps
12821286

test/test_env.py

+28
Original file line numberDiff line numberDiff line change
@@ -3526,6 +3526,34 @@ def test_single_env_spec():
35263526
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))
35273527

35283528

3529+
def test_auto_spec():
3530+
env = CountingEnv()
3531+
td = env.reset()
3532+
3533+
policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
3534+
action_spec.rand()
3535+
)
3536+
3537+
env.full_observation_spec = Composite(
3538+
shape=env.full_observation_spec.shape, device=env.full_observation_spec.device
3539+
)
3540+
env.full_action_spec = Composite(
3541+
shape=env.full_action_spec.shape, device=env.full_action_spec.device
3542+
)
3543+
env.full_reward_spec = Composite(
3544+
shape=env.full_reward_spec.shape, device=env.full_reward_spec.device
3545+
)
3546+
env.full_done_spec = Composite(
3547+
shape=env.full_done_spec.shape, device=env.full_done_spec.device
3548+
)
3549+
env.full_state_spec = Composite(
3550+
shape=env.full_state_spec.shape, device=env.full_state_spec.device
3551+
)
3552+
env._action_keys = ["action"]
3553+
env.auto_specs_(policy, tensordict=td.copy())
3554+
env.check_env_specs(tensordict=td.copy())
3555+
3556+
35293557
if __name__ == "__main__":
35303558
args, unknown = argparse.ArgumentParser().parse_known_args()
35313559
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_specs.py

+9
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,15 @@ def test_getitem(self, shape, is_complete, device, dtype):
412412
with pytest.raises(KeyError):
413413
_ = ts["UNK"]
414414

415+
def test_setitem_newshape(self, shape, is_complete, device, dtype):
416+
ts = self._composite_spec(shape, is_complete, device, dtype)
417+
new_spec = ts.clone()
418+
new_spec.shape = torch.Size(())
419+
new_spec.clear_device_()
420+
ts["new_spec"] = new_spec
421+
assert ts["new_spec"].shape == ts.shape
422+
assert ts["new_spec"].device == ts.device
423+
415424
def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype):
416425
ts = self._composite_spec(shape, is_complete, device, dtype)
417426
for key in {"shape", "device", "dtype", "space"}:

torchrl/data/tensor_specs.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -4372,11 +4372,20 @@ def set(self, name, spec):
43724372
if spec is not None:
43734373
shape = spec.shape
43744374
if shape[: self.ndim] != self.shape:
4375-
raise ValueError(
4376-
"The shape of the spec and the Composite mismatch: the first "
4377-
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
4378-
f"Composite.shape={self.shape}."
4379-
)
4375+
if (
4376+
isinstance(spec, Composite)
4377+
and spec.ndim < self.ndim
4378+
and self.shape[: spec.ndim] == spec.shape
4379+
):
4380+
# Try to set the composite shape
4381+
spec = spec.clone()
4382+
spec.shape = self.shape
4383+
else:
4384+
raise ValueError(
4385+
"The shape of the spec and the Composite mismatch: the first "
4386+
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
4387+
f"Composite.shape={self.shape}."
4388+
)
43804389
self._specs[name] = spec
43814390

43824391
def __init__(
@@ -4448,6 +4457,8 @@ def clear_device_(self):
44484457
"""Clears the device of the Composite."""
44494458
self._device = None
44504459
for spec in self._specs.values():
4460+
if spec is None:
4461+
continue
44514462
spec.clear_device_()
44524463
return self
44534464

@@ -4530,6 +4541,10 @@ def __setitem__(self, key, value):
45304541
and value.device != self.device
45314542
):
45324543
if isinstance(value, Composite) and value.device is None:
4544+
# We make a clone not to mess up the spec that was provided.
4545+
# in set() we do the same for shape - these two ops should be grouped.
4546+
# we don't care about the overhead of cloning twice though because in theory
4547+
# we don't set specs often.
45334548
value = value.clone().to(self.device)
45344549
else:
45354550
raise RuntimeError(

0 commit comments

Comments
 (0)