diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 8c916b5c3d7..4e0c4d00967 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -10,11 +10,11 @@
 
 import torch
 import torch.nn as nn
-from tensordict import TensorDict, TensorDictBase
+from tensordict import tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModuleBase
 from tensordict.utils import expand_right, NestedKey
 
-from torchrl.data.tensor_specs import (
+from torchrl.data import (
     Binary,
     Bounded,
     Categorical,
@@ -2356,3 +2356,42 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDict:
 
     def _set_seed(self, seed: Optional[int]):
         ...
+
+
+@tensorclass()
+class TC:
+    field0: str
+    field1: torch.Tensor
+
+
+class EnvWithTensorClass(CountingEnv):
+    tc_cls = TC
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.observation_spec["tc"] = Composite(
+            field0=NonTensor(example_data="an observation!", shape=self.batch_size),
+            field1=Unbounded(shape=self.batch_size),
+            shape=self.batch_size,
+            data_cls=TC,
+        )
+
+    def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+        td = super()._reset(tensordict, **kwargs)
+        td["tc"] = TC("0", torch.zeros(self.batch_size))
+        return td
+
+    def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+        td = super()._step(tensordict, **kwargs)
+        default = TC("0", 0)
+        f0 = tensordict.get("tc", default).field0
+        if f0 is None:
+            f0 = "0"
+        f1 = tensordict.get("tc", default).field1
+        if f1 is None:
+            f1 = torch.zeros(self.batch_size)
+        td["tc"] = TC(
+            str(int(f0) + 1),
+            f1 + 1,
+        )
+        return td
diff --git a/test/test_env.py b/test/test_env.py
index ad02467d6ab..74b941de580 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -127,6 +127,7 @@
     EnvThatDoesNothing,
     EnvWithDynamicSpec,
     EnvWithMetadata,
+    EnvWithTensorClass,
     HeterogeneousCountingEnv,
     HeterogeneousCountingEnvPolicy,
     MockBatchedLockedEnv,
@@ -166,6 +167,7 @@
     EnvThatDoesNothing,
     EnvWithDynamicSpec,
     EnvWithMetadata,
+    EnvWithTensorClass,
     HeterogeneousCountingEnv,
     HeterogeneousCountingEnvPolicy,
     MockBatchedLockedEnv,
@@ -3708,6 +3710,29 @@ def test_str2str_rb_slicesampler(self):
         else:
             raise RuntimeError("Failed to sample both trajs")
 
+    def test_env_with_tensorclass(self):
+        env = EnvWithTensorClass()
+        env.check_env_specs()
+        r = env.reset()
+        for _ in range(3):
+            assert isinstance(r["tc"], env.tc_cls)
+            a = env.rand_action(r)
+            s = env.step(a)
+            assert isinstance(s["tc"], env.tc_cls)
+            r = env.step_mdp(s)
+
+    @pytest.mark.parametrize("cls", [SerialEnv, ParallelEnv])
+    def test_env_with_tensorclass_batched(self, cls):
+        env = cls(2, EnvWithTensorClass)
+        env.check_env_specs()
+        r = env.reset()
+        for _ in range(3):
+            assert isinstance(r["tc"], EnvWithTensorClass.tc_cls)
+            a = env.rand_action(r)
+            s = env.step(a)
+            assert isinstance(s["tc"], EnvWithTensorClass.tc_cls)
+            r = env.step_mdp(s)
+
 
 # fen strings for board positions generated with:
 # https://lichess.org/editor
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 58ae156bde5..a550ac4e42b 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -2538,6 +2538,8 @@ def one(self, shape=None):
         )
 
     def is_in(self, val: Any) -> bool:
+        if not isinstance(val, torch.Tensor) and not is_tensor_collection(val):
+            return True
         shape = torch.broadcast_shapes(self._safe_shape, val.shape)
         return (
             is_non_tensor(val)
@@ -4487,6 +4489,8 @@ class Composite(TensorSpec):
             to ``None``.
         shape (torch.Size): the leading shape of all the leaves.
             Equivalent to the batch-size of the corresponding tensordicts.
+        data_cls (type, optional): the tensordict subclass (TensorDict, TensorClass, tensorclass...) that should be
+            enforced in the env. Defaults to ``None``.
 
     Examples:
         >>> pixels_spec = Bounded(
@@ -4556,6 +4560,48 @@ def __new__(cls, *args, **kwargs):
         cls._is_locked = False
         return super().__new__(cls)
 
+    def __init__(
+        self,
+        *args,
+        shape: torch.Size = None,
+        device: torch.device = None,
+        data_cls: type | None = None,
+        **kwargs,
+    ):
+        # For compatibility with TensorDict
+        batch_size = kwargs.pop("batch_size", None)
+        if batch_size is not None:
+            if shape is not None:
+                raise TypeError("Cannot specify both batch_size and shape.")
+            shape = batch_size
+
+        if shape is None:
+            shape = _size(())
+        self._shape = _size(shape)
+        self._specs = {}
+
+        _device = (
+            _make_ordinal_device(torch.device(device)) if device is not None else device
+        )
+        self._device = _device
+        if len(args):
+            if len(args) > 1:
+                raise RuntimeError(
+                    "Got multiple arguments, when at most one is expected for Composite."
+                )
+            argdict = args[0]
+            if not isinstance(argdict, (dict, Composite)):
+                raise RuntimeError(
+                    f"Expected a dictionary of specs, but got an argument of type {type(argdict)}."
+                )
+            for k, item in argdict.items():
+                if isinstance(item, dict):
+                    item = Composite(item, shape=shape, device=_device)
+                self[k] = item
+        for k, item in kwargs.items():
+            self[k] = item
+        self.data_cls = data_cls
+
     @property
     def batch_size(self):
         return self._shape
@@ -4704,42 +4750,6 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
         self._specs[name] = spec
         return self
 
-    def __init__(
-        self, *args, shape: torch.Size = None, device: torch.device = None, **kwargs
-    ):
-        # For compatibility with TensorDict
-        batch_size = kwargs.pop("batch_size", None)
-        if batch_size is not None:
-            if shape is not None:
-                raise TypeError("Cannot specify both batch_size and shape.")
-            shape = batch_size
-
-        if shape is None:
-            shape = _size(())
-        self._shape = _size(shape)
-        self._specs = {}
-
-        _device = (
-            _make_ordinal_device(torch.device(device)) if device is not None else device
-        )
-        self._device = _device
-        if len(args):
-            if len(args) > 1:
-                raise RuntimeError(
-                    "Got multiple arguments, when at most one is expected for Composite."
-                )
-            argdict = args[0]
-            if not isinstance(argdict, (dict, Composite)):
-                raise RuntimeError(
-                    f"Expected a dictionary of specs, but got an argument of type {type(argdict)}."
-                )
-            for k, item in argdict.items():
-                if isinstance(item, dict):
-                    item = Composite(item, shape=shape, device=_device)
-                self[k] = item
-        for k, item in kwargs.items():
-            self[k] = item
-
     @property
     def device(self) -> DEVICE_TYPING:
         return self._device
@@ -4868,6 +4878,8 @@ def encode(
     ) -> Dict[str, torch.Tensor]:
         if isinstance(vals, TensorDict):
             out = vals.empty()  # create and empty tensordict similar to vals
+        elif self.data_cls is not None:
+            out = {}
         else:
             out = TensorDict._new_unsafe({}, _size([]))
         for key, item in vals.items():
@@ -4885,6 +4897,8 @@ def encode(
                 raise RuntimeError(
                     f"Encoding key {key} raised a RuntimeError. Scroll up to know more."
                 ) from err
+        if self.data_cls is not None:
+            return self.data_cls.from_dict(out)
         return out
 
     def __repr__(self) -> str:
@@ -4910,6 +4924,14 @@ def type_check(
             self._specs[_key].type_check(value[_key], _key)
 
     def is_in(self, val: Union[dict, TensorDictBase]) -> bool:
+        # TODO: make warnings for these
+        # if val.device != self.device:
+        #     print(val.device, self.device)
+        #     return False
+        # if val.shape[-self.ndim:] != self.shape:
+        #     return False
+        if self.data_cls is not None and type(val) != self.data_cls:
+            return False
         for key, item in self._specs.items():
             if item is None or (isinstance(item, Composite) and item.is_empty()):
                 continue
@@ -4934,12 +4956,16 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
         for key, item in self.items():
             if item is not None:
                 _dict[key] = item.rand(shape)
+        if self.data_cls is None:
+            cls = TensorDict
+        else:
+            cls = self.data_cls
         # No need to run checks since we know Composite is compliant with
         # TensorDict requirements
-        return TensorDict(
+        return cls.from_dict(
             _dict,
             batch_size=_size([*shape, *_remove_neg_shapes(self.shape)]),
-            device=self._device,
+            device=self.device,
         )
 
     def keys(
@@ -5046,7 +5072,9 @@ def _reshape(self, shape):
             key: val.reshape((*shape, *val.shape[self.ndimension() :]))
             for key, val in self._specs.items()
         }
-        return Composite(_specs, shape=shape)
+        return self.__class__(
+            _specs, shape=shape, device=self.device, data_cls=self.data_cls
+        )
 
     def _unflatten(self, dim, sizes):
         shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
@@ -5073,7 +5101,9 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite:
                 kwargs[key] = value
                 continue
             kwargs[key] = value.to(dest)
-        return self.__class__(**kwargs, device=_device, shape=self.shape)
+        return self.__class__(
+            **kwargs, device=_device, shape=self.shape, data_cls=self.data_cls
+        )
 
     def clone(self) -> Composite:
         """Clones the Composite spec.
@@ -5091,6 +5121,7 @@ def clone(self) -> Composite:
             },
             device=device,
             shape=self.shape,
+            data_cls=self.data_cls,
         )
 
     def cardinality(self) -> int:
@@ -5112,6 +5143,10 @@ def enumerate(self) -> TensorDictBase:
         while self_without_batch.ndim:
             self_without_batch = self_without_batch[0]
         samples = {key: spec.enumerate() for key, spec in self_without_batch.items()}
+        if self.data_cls is not None:
+            cls = self.data_cls
+        else:
+            cls = TensorDict
         if samples:
             idx_rep = torch.meshgrid(
                 *(torch.arange(s.shape[0]) for s in samples.values()), indexing="ij"
@@ -5121,7 +5156,7 @@ def enumerate(self) -> TensorDictBase:
                 key: sample[idx]
                 for ((key, sample), idx) in zip(samples.items(), idx_rep)
             }
-            samples = TensorDict(
+            samples = cls.from_dict(
                 samples, batch_size=idx_rep[0].shape[:1], device=self.device
             )
             # Expand
@@ -5129,7 +5164,7 @@ def enumerate(self) -> TensorDictBase:
                 samples = samples.reshape(-1, *(1,) * self.ndim)
                 samples = samples.expand(samples.shape[0], *self.shape)
         else:
-            samples = TensorDict(batch_size=self.shape, device=self.device)
+            samples = cls.from_dict({}, batch_size=self.shape, device=self.device)
         return samples
 
     def empty(self):
@@ -5142,6 +5177,7 @@ def empty(self):
             {},
             device=device,
             shape=self.shape,
+            data_cls=self.data_cls,
         )
 
     def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
@@ -5154,13 +5190,19 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
             device = self.device
         except RuntimeError:
             device = self._device
-        return TensorDict(
+
+        if self.data_cls is not None:
+            cls = self.data_cls
+        else:
+            cls = TensorDict
+
+        return cls.from_dict(
             {
                 key: self[key].zero(shape)
                 for key in self.keys(True)
                 if isinstance(key, str) and self[key] is not None
             },
-            _size([*shape, *self._safe_shape]),
+            batch_size=_size([*shape, *self._safe_shape]),
             device=device,
         )
 
@@ -5171,6 +5213,7 @@ def __eq__(self, other):
             and self._device == other._device
             and set(self._specs.keys()) == set(other._specs.keys())
             and all((self._specs[key] == spec) for (key, spec) in other._specs.items())
+            and other.data_cls == self.data_cls
         )
 
     def update(self, dict_or_spec: Union[Composite, Dict[str, TensorSpec]]) -> None:
@@ -5220,6 +5263,7 @@ def expand(self, *shape):
             specs,
             shape=shape,
             device=device,
+            data_cls=self.data_cls,
         )
         return out
 
@@ -5237,10 +5281,11 @@ def squeeze(self, dim: int | None = None):
             except RuntimeError:
                 device = self._device
 
-            return Composite(
+            return self.__class__(
                 {key: value.squeeze(dim) for key, value in self.items()},
                 shape=shape,
                 device=device,
+                data_cls=self.data_cls,
             )
 
         if self.shape.count(1) == 0:
@@ -5263,13 +5308,14 @@ def unsqueeze(self, dim: int):
             except RuntimeError:
                 device = self._device
 
-            return Composite(
+            return self.__class__(
                 {
                     key: value.unsqueeze(dim) if value is not None else None
                     for key, value in self.items()
                 },
                 shape=shape,
                 device=device,
+                data_cls=self.data_cls,
             )
 
     def unbind(self, dim: int = 0):
@@ -5287,6 +5333,7 @@ def unbind(self, dim: int = 0):
                 {key: val[i] for key, val in unbound_vals.items()},
                 shape=shape,
                 device=self.device,
+                data_cls=self.data_cls,
             )
             for i in range(self.shape[dim])
         )
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 7bc9b0c1a5a..184cf5892f0 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -2412,7 +2412,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                 raise RuntimeError("called 'init' before step")
             i += 1
             # No need to copy here since we don't write in-place
-            input = root_shared_tensordict
+            input = root_shared_tensordict.copy()
             if data:
                 next_td_passthrough_keys = data.get("next_td_passthrough_keys")
                 if next_td_passthrough_keys is not None:
@@ -2423,7 +2423,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
                 if non_tensor_data is not None:
                     input.update(non_tensor_data)
 
-            input = env.step(input.copy())
+            input = env.step(input)
             next_td = input.get("next")
             next_shared_tensordict.update_(next_td, non_blocking=non_blocking)
 
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 39b0faa9692..b4705030547 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -216,12 +216,12 @@ def _grab_and_place(
             val = data_in._get_str(key, NO_DEFAULT)
             if subdict is not None:
                 val_out = data_out._get_str(key, None)
-                if val_out is None:
-                    val_out = val.empty()
+                if val_out is None or val_out.batch_size != val.batch_size:
+                    val_out = val.empty(batch_size=val.batch_size)
                 if isinstance(val, LazyStackedTensorDict):
-                    val = LazyStackedTensorDict(
-                        *(
+                    val = LazyStackedTensorDict.lazy_stack(
+                        [
                             cls._grab_and_place(
                                 subdict,
                                 _val,
@@ -232,8 +232,8 @@ def _grab_and_place(
                                 val.unbind(val.stack_dim),
                                 val_out.unbind(val_out.stack_dim),
                             )
-                        ),
-                        stack_dim=val.stack_dim,
+                        ],
+                        dim=val.stack_dim,
                     )
                 else:
                     val = cls._grab_and_place(
@@ -302,8 +302,8 @@ def __call__(self, tensordict):
             )
             if isinstance(next_td, LazyStackedTensorDict):
                 if not isinstance(out, LazyStackedTensorDict):
-                    out = LazyStackedTensorDict(
-                        *out.unbind(next_td.stack_dim), stack_dim=next_td.stack_dim
+                    out = LazyStackedTensorDict.lazy_stack(
+                        list(out.unbind(next_td.stack_dim)), dim=next_td.stack_dim
                     )
                 for _next_td, _out in zip(next_td.tensordicts, out.tensordicts):
                     self._grab_and_place(