From bcc5f565a5a6b3a793274e7d65ec0c2be51191dc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 28 Feb 2025 14:38:19 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/data/tensor_specs.py | 284 ++++++++++++++++++++--------------- 1 file changed, 167 insertions(+), 117 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1e069eb93db..a5473c26e67 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -39,6 +39,7 @@ LazyStackedTensorDict, NonTensorData, NonTensorStack, + set_capture_non_tensor_stack, TensorDict, TensorDictBase, unravel_key, @@ -112,10 +113,10 @@ class _NoDefault(enum.IntEnum): def _default_dtype_and_device( - dtype: Union[None, torch.dtype], - device: Union[None, str, int, torch.device], + dtype: None | torch.dtype, + device: None | str | int | torch.device, allow_none_device: bool = False, -) -> Tuple[torch.dtype, torch.device | None]: +) -> tuple[torch.dtype, torch.device | None]: if dtype is None: dtype = torch.get_default_dtype() if device is not None: @@ -159,7 +160,7 @@ def _validate_iterable( ) -def _slice_indexing(shape: list[int], idx: slice) -> List[int]: +def _slice_indexing(shape: list[int], idx: slice) -> list[int]: """Given an input shape and a slice index, returns the new indexed shape. Args: @@ -206,8 +207,8 @@ def _slice_indexing(shape: list[int], idx: slice) -> List[int]: def _shape_indexing( - shape: Union[list[int], torch.Size, Tuple[int]], idx: SHAPE_INDEX_TYPING -) -> List[int]: + shape: list[int] | torch.Size | tuple[int], idx: SHAPE_INDEX_TYPING +) -> list[int]: """Given an input shape and an index, returns the size of the resulting indexed spec. This function includes indexing checks and may raise IndexErrors. 
@@ -373,7 +374,7 @@ class Box: def __iter__(self): raise NotImplementedError - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> ContinuousBox: raise NotImplementedError def __repr__(self): @@ -430,7 +431,7 @@ def __iter__(self): yield self.low yield self.high - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> ContinuousBox: return self.__class__(self.low.to(dest), self.high.to(dest)) def clone(self) -> ContinuousBox: @@ -486,7 +487,7 @@ def __post_init__(self): # We want to make sure we're working with a regular integer self.__dict__["n"] = int(self.n) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> CategoricalBox: return deepcopy(self) def __repr__(self): @@ -503,14 +504,13 @@ class DiscreteBox(CategoricalBox): class BoxList(Box): """A box of discrete values.""" - boxes: List + boxes: list - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> BoxList: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> BoxList: return BoxList([box.to(dest) for box in self.boxes]) def __iter__(self): - for elt in self.boxes: - yield elt + yield from self.boxes def __repr__(self): return f"{self.__class__.__name__}(boxes={self.boxes})" @@ -532,7 +532,7 @@ class BinaryBox(Box): n: int - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> ContinuousBox: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> ContinuousBox: return deepcopy(self) def __repr__(self): @@ -569,7 +569,7 @@ class TensorSpec(metaclass=abc.ABCMeta): """ shape: torch.Size - space: Union[None, Box] + space: None | Box device: torch.device | None = None dtype: torch.dtype = torch.float domain: str = "" @@ -837,7 +837,7 @@ def reshape(self, *shape) -> T: def _reshape(self, shape: torch.Size) -> T: ... 
- def unflatten(self, dim: int, sizes: Tuple[int]) -> T: + def unflatten(self, dim: int, sizes: tuple[int]) -> T: """Unflattens a ``TensorSpec``. Check :func:`~torch.unflatten` for more information on this method. @@ -845,7 +845,7 @@ def unflatten(self, dim: int, sizes: Tuple[int]) -> T: """ return self._unflatten(dim, sizes) - def _unflatten(self, dim: int, sizes: Tuple[int]) -> T: + def _unflatten(self, dim: int, sizes: tuple[int]) -> T: shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape return self._reshape(shape) @@ -1021,7 +1021,7 @@ def ones(self, shape: torch.Size = None) -> torch.Tensor | TensorDictBase: return self.one(shape=shape) @abc.abstractmethod - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> "TensorSpec": + def to(self, dest: torch.dtype | DEVICE_TYPING) -> TensorSpec: """Casts a TensorSpec to a device or a dtype. Returns the same spec if no change is made. @@ -1039,7 +1039,7 @@ def cuda(self, device=None): return self.to(f"cuda:{device}") @abc.abstractmethod - def clone(self) -> "TensorSpec": + def clone(self) -> TensorSpec: """Creates a copy of the TensorSpec.""" ... 
@@ -1060,8 +1060,8 @@ def __torch_function__( cls, func: Callable, types, - args: Tuple = (), - kwargs: Optional[dict] = None, + args: tuple = (), + kwargs: dict | None = None, ) -> Callable: if kwargs is None: kwargs = {} @@ -1081,7 +1081,7 @@ def unbind(self, dim: int = 0): class _LazyStackedMixin(Generic[T]): - def __init__(self, *specs: Tuple[T, ...], dim: int) -> None: + def __init__(self, *specs: tuple[T, ...], dim: int) -> None: self._specs = list(specs) self.dim = dim if self.dim < 0: @@ -1221,7 +1221,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase: ) return torch.nested.nested_tensor(samples) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> T: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> T: if dest is None: return self return torch.stack([spec.to(dest) for spec in self._specs], self.dim) @@ -1496,7 +1496,7 @@ def _project(self, val: TensorDictBase) -> TensorDictBase: raise NOT_IMPLEMENTED_ERROR def encode( - self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False + self, val: np.ndarray | torch.Tensor, *, ignore_device=False ) -> torch.Tensor: if self.dim != 0 and not isinstance(val, tuple): val = val.unbind(self.dim) @@ -1574,9 +1574,9 @@ class OneHot(TensorSpec): def __init__( self, n: int, - shape: Optional[torch.Size] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.bool, + shape: torch.Size | None = None, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype | None = torch.bool, use_register: bool = False, mask: torch.Tensor | None = None, ): @@ -1635,7 +1635,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> OneHot: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> OneHot: if dest is None: return self if isinstance(dest, torch.dtype): @@ -1821,8 +1821,8 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: # noqa: F811 def 
encode( self, - val: Union[np.ndarray, torch.Tensor], - space: Optional[CategoricalBox] = None, + val: np.ndarray | torch.Tensor, + space: CategoricalBox | None = None, *, ignore_device: bool = False, ) -> torch.Tensor: @@ -2082,11 +2082,11 @@ class Bounded(TensorSpec, metaclass=_BoundedMeta): def __init__( self, - low: Union[float, torch.Tensor, np.ndarray] = None, - high: Union[float, torch.Tensor, np.ndarray] = None, - shape: Optional[Union[torch.Size, int]] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[torch.dtype, str]] = None, + low: float | torch.Tensor | np.ndarray = None, + high: float | torch.Tensor | np.ndarray = None, + shape: torch.Size | int | None = None, + device: DEVICE_TYPING | None = None, + dtype: torch.dtype | str | None = None, **kwargs, ): if "maximum" in kwargs: @@ -2400,7 +2400,7 @@ def is_in(self, val: torch.Tensor) -> bool: return False raise err - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Bounded: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> Bounded: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2452,11 +2452,11 @@ class BoundedContinuous(Bounded, metaclass=_BoundedMeta): def __init__( self, - low: Union[float, torch.Tensor, np.ndarray] = None, - high: Union[float, torch.Tensor, np.ndarray] = None, - shape: Optional[Union[torch.Size, int]] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[torch.dtype, str]] = None, + low: float | torch.Tensor | np.ndarray = None, + high: float | torch.Tensor | np.ndarray = None, + shape: torch.Size | int | None = None, + device: DEVICE_TYPING | None = None, + dtype: torch.dtype | str | None = None, domain: str = "continuous", ): super().__init__( @@ -2469,11 +2469,11 @@ class BoundedDiscrete(Bounded, metaclass=_BoundedMeta): def __init__( self, - low: Union[float, torch.Tensor, np.ndarray] = None, - high: Union[float, torch.Tensor, np.ndarray] = None, - shape: Optional[Union[torch.Size, int]] = 
None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[torch.dtype, str]] = None, + low: float | torch.Tensor | np.ndarray = None, + high: float | torch.Tensor | np.ndarray = None, + shape: torch.Size | int | None = None, + device: DEVICE_TYPING | None = None, + dtype: torch.dtype | str | None = None, domain: str = "discrete", ): super().__init__( @@ -2503,23 +2503,50 @@ def _is_nested_list(index, notuple=False): class NonTensor(TensorSpec): """A spec for non-tensor data. - This spec has a shae, device and dtype like :class:`~tensordict.NonTensorData`. + The `NonTensor` class is designed to handle specifications for data that do not conform to standard tensor + structures. + It maintains attributes such as shape and device, similar to the `NonTensorData` class. + The dtype is optional and should in practice be left to `None` in most cases. + Methods like `rand`, `zero`, and `one` will return a `NonTensorData` object with a `None` data value. - :meth:`.rand` will return a :class:`~tensordict.NonTensorData` object with `None` data value. - (same will go for :meth:`.zero` and :meth:`.one`). + .. warning:: The default shape of `NonTensor` is `(1,)`. - .. note:: The default shape of `NonTensor` is `(1,)`. + Args: + shape (Union[torch.Size, int], optional): The shape of the non-tensor data. Defaults to `(1,)`. + device (Optional[DEVICE_TYPING], optional): The device on which the data is stored. Defaults to `None`. + dtype (torch.dtype | None, optional): The data type of the non-tensor data. Defaults to `None`. + example_data (Any, optional): An example of the data that this spec represents. This example is used as a + template when generating new data with the `rand`, `zero`, and `one` methods. + batched (bool, optional): Indicates whether the data is batched. If `True`, the `rand`, `zero`, and `one` methods + will generate data with an additional batch dimension, stacking copies of the `example_data` across this dimension. + Defaults to `False`.
+ **kwargs: Additional keyword arguments passed to the parent class. + + .. seealso:: :class:`~torchrl.data.Choice` which allows to randomly choose among different specs when calling + `rand`. + Examples: + >>> from torchrl.data import NonTensor + >>> spec = NonTensor(example_data="a string", batched=False, shape=(3,)) + >>> spec.rand() + NonTensorData(data=a string, batch_size=torch.Size([3]), device=None) + >>> spec = NonTensor(example_data="a string", batched=True, shape=(3,)) + >>> spec.rand() + NonTensorStack( + ['a string', 'a string', 'a string'], + batch_size=torch.Size([3]), + device=None) """ example_data: Any = None def __init__( self, - shape: Union[torch.Size, int] = _DEFAULT_SHAPE, - device: Optional[DEVICE_TYPING] = None, + shape: torch.Size | int = _DEFAULT_SHAPE, + device: DEVICE_TYPING | None = None, dtype: torch.dtype | None = None, example_data: Any = None, + batched: bool = False, **kwargs, ): if isinstance(shape, int): @@ -2530,6 +2557,20 @@ def __init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) self.example_data = example_data + self.batched = batched + + def __repr__(self): + shape_str = indent("shape=" + str(self.shape), " " * 4) + space_str = indent("space=" + str(self.space), " " * 4) + device_str = indent("device=" + str(self.device), " " * 4) + dtype_str = indent("dtype=" + str(self.dtype), " " * 4) + domain_str = indent("domain=" + str(self.domain), " " * 4) + example_str = indent("example_data=" + str(self.example_data), " " * 4) + sub_string = ",\n".join( + [shape_str, space_str, device_str, dtype_str, domain_str, example_str] + ) + string = f"{self.__class__.__name__}(\n{sub_string})" + return string def __eq__(self, other): eq = super().__eq__(other) @@ -2550,7 +2591,7 @@ def cardinality(self) -> Any: def enumerate(self, use_mask: bool = False) -> Any: raise NotImplementedError("Cannot enumerate a NonTensor spec.") - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: + def 
to(self, dest: torch.dtype | DEVICE_TYPING) -> NonTensor: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2566,6 +2607,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: device=dest_device, dtype=None, example_data=self.example_data, + batched=self.batched, ) def clone(self) -> NonTensor: @@ -2574,11 +2616,24 @@ def clone(self) -> NonTensor: device=self.device, dtype=self.dtype, example_data=self.example_data, + batched=self.batched, ) def rand(self, shape=None): if shape is None: shape = () + if self.batched: + with set_capture_non_tensor_stack(False): + val = NonTensorData( + data=self.example_data, + batch_size=(), + device=self.device, + ) + shape = (*shape, *self._safe_shape) + if shape: + for i in shape: + val = torch.stack([val.copy() for _ in range(i)], -1) + return val return NonTensorData( data=self.example_data, batch_size=(*shape, *self._safe_shape), @@ -2586,22 +2641,10 @@ def rand(self, shape=None): ) def zero(self, shape=None): - if shape is None: - shape = () - return NonTensorData( - data=self.example_data, - batch_size=(*shape, *self._safe_shape), - device=self.device, - ) + return self.rand(shape=shape) def one(self, shape=None): - if shape is None: - shape = () - return NonTensorData( - data=self.example_data, - batch_size=(*shape, *self._safe_shape), - device=self.device, - ) + return self.rand(shape=shape) def is_in(self, val: Any) -> bool: if not isinstance(val, torch.Tensor) and not is_tensor_collection(val): @@ -2613,6 +2656,7 @@ def is_in(self, val: Any) -> bool: # We relax constrains on device as they're hard to enforce for non-tensor # tensordicts and pointless # and val.device == self.device + # TODO: do we want this? and val.dtype == self.dtype ) @@ -2628,17 +2672,23 @@ def expand(self, *shape): f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}." 
) return self.__class__( - shape=shape, device=self.device, dtype=None, example_data=self.example_data + shape=shape, + device=self.device, + dtype=None, + example_data=self.example_data, + batched=self.batched, ) def unsqueeze(self, dim: int) -> NonTensor: unsq = super().unsqueeze(dim=dim) unsq.example_data = self.example_data + unsq.batched = self.batched return unsq def squeeze(self, dim: int | None = None) -> NonTensor: sq = super().squeeze(dim=dim) sq.example_data = self.example_data + sq.batched = self.batched return sq def _reshape(self, shape): @@ -2647,6 +2697,7 @@ def _reshape(self, shape): device=self.device, dtype=self.dtype, example_data=self.example_data, + batched=self.batched, ) def _unflatten(self, dim, sizes): @@ -2656,6 +2707,7 @@ def _unflatten(self, dim, sizes): device=self.device, dtype=self.dtype, example_data=self.example_data, + batched=self.batched, ) def __getitem__(self, idx: SHAPE_INDEX_TYPING): @@ -2666,6 +2718,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): device=self.device, dtype=self.dtype, example_data=self.example_data, + batched=self.batched, ) def unbind(self, dim: int = 0): @@ -2683,6 +2736,7 @@ def unbind(self, dim: int = 0): device=self.device, dtype=self.dtype, example_data=self.example_data, + batched=self.batched, ) for i in range(self.shape[dim]) ) @@ -2774,9 +2828,9 @@ class Unbounded(TensorSpec, metaclass=_UnboundedMeta): def __init__( self, - shape: Union[torch.Size, int] = _DEFAULT_SHAPE, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = None, + shape: torch.Size | int = _DEFAULT_SHAPE, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype | None = None, **kwargs, ): if isinstance(shape, int): @@ -2822,7 +2876,7 @@ def index( ) -> torch.Tensor | TensorDictBase: raise NotImplementedError("`index` is not implemented for Unbounded specs.") - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Unbounded: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> 
Unbounded: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -2841,7 +2895,7 @@ def clone(self) -> Unbounded: def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = _size([]) - shape = [*shape, *self.shape] + shape = [*shape, *self._safe_shape] if self.dtype.is_floating_point: return torch.randn(shape, device=self.device, dtype=self.dtype) return torch.empty(shape, device=self.device, dtype=self.dtype).random_() @@ -2946,9 +3000,9 @@ class UnboundedDiscrete(Unbounded): def __init__( self, - shape: Union[torch.Size, int] = _DEFAULT_SHAPE, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.int64, + shape: torch.Size | int = _DEFAULT_SHAPE, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype | None = torch.int64, **kwargs, ): super().__init__(shape=shape, device=device, dtype=dtype, **kwargs) @@ -2994,7 +3048,7 @@ class MultiOneHot(OneHot): def __init__( self, nvec: Sequence[int], - shape: Optional[torch.Size] = None, + shape: torch.Size | None = None, device=None, dtype=torch.bool, use_register=False, @@ -3071,7 +3125,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiOneHot: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> MultiOneHot: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3119,7 +3173,7 @@ def __eq__(self, other): and mask_equal ) - def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: + def rand(self, shape: torch.Size | None = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: @@ -3160,7 +3214,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: return torch.cat(out, -1) def encode( - self, val: Union[np.ndarray, torch.Tensor], *, ignore_device: bool = False + self, val: np.ndarray | torch.Tensor, *, ignore_device: bool = False ) -> 
torch.Tensor: if not isinstance(val, torch.Tensor): if not ignore_device: @@ -3174,12 +3228,10 @@ def encode( raise RuntimeError( f"value {v} is greater than the allowed max {space.n}" ) - x.append( - super(MultiOneHot, self).encode(v, space, ignore_device=ignore_device) - ) + x.append(super().encode(v, space, ignore_device=ignore_device)) return torch.cat(x, -1).reshape(self.shape) - def _split(self, val: torch.Tensor) -> Optional[torch.Tensor]: + def _split(self, val: torch.Tensor) -> torch.Tensor | None: split_sizes = [space.n for space in self.space] if val.ndim < 1 or val.shape[-1] != sum(split_sizes): return None @@ -3806,7 +3858,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Categorical: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> Categorical: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -3862,17 +3914,15 @@ class Choice(TensorSpec): def __init__( self, - choices: List[TensorSpec | NonTensorData | NonTensorStack], + choices: list[TensorSpec | NonTensorData | NonTensorStack], ): if not isinstance(choices, list): raise TypeError("'choices' must be a list") if not isinstance(choices[0], (TensorSpec, NonTensorData, NonTensorStack)): raise TypeError( - ( - "Each choice must be either a TensorSpec, NonTensorData, or " - f"NonTensorStack, but got {type(choices[0])}" - ) + "Each choice must be either a TensorSpec, NonTensorData, or " + f"NonTensorStack, but got {type(choices[0])}" ) if not all([isinstance(choice, type(choices[0])) for choice in choices[1:]]): @@ -3941,7 +3991,7 @@ def cardinality(self) -> int: .item() ) - def enumerate(self, use_mask: bool = False) -> List[Any]: + def enumerate(self, use_mask: bool = False) -> list[Any]: return [s for choice in self._choices for s in choice.enumerate()] def _project( @@ -3970,7 +4020,7 @@ def num_choices(self): """Number of choices for the spec.""" return len(self._choices) - def to(self, 
dest: Union[torch.dtype, DEVICE_TYPING]) -> Choice: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> Choice: return self.__class__([choice.to(dest) for choice in self._choices]) def __eq__(self, other): @@ -4028,9 +4078,9 @@ class Binary(Categorical): def __init__( self, n: int | None = None, - shape: Optional[torch.Size] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Union[str, torch.dtype] = torch.int8, + shape: torch.Size | None = None, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype = torch.int8, ): if n is None and shape is None: raise TypeError("Must provide either n or shape.") @@ -4107,7 +4157,7 @@ def unbind(self, dim: int = 0): for i in range(self.shape[dim]) ) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Binary: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> Binary: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -4183,10 +4233,10 @@ class MultiCategorical(Categorical): def __init__( self, - nvec: Union[Sequence[int], torch.Tensor, int], - shape: Optional[torch.Size] = None, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.int64, + nvec: Sequence[int] | torch.Tensor | int, + shape: torch.Size | None = None, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype | None = torch.int64, mask: torch.Tensor | None = None, remove_singleton: bool = True, ): @@ -4282,7 +4332,7 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> MultiCategorical: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> MultiCategorical: if isinstance(dest, torch.dtype): dest_dtype = dest dest_device = self.device @@ -4347,7 +4397,7 @@ def _rand(self, space: Box, shape: torch.Size, i: int): ) return torch.stack(x, -1) - def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: + def rand(self, shape: torch.Size | None = None) -> 
torch.Tensor: if self.mask is not None: splits = self._split_self() return torch.stack([split.rand(shape) for split in splits], -1) @@ -4436,7 +4486,7 @@ def is_in(self, val: torch.Tensor) -> bool: def to_one_hot( self, val: torch.Tensor, safe: bool = None - ) -> Union[MultiOneHot, torch.Tensor]: + ) -> MultiOneHot | torch.Tensor: """Encodes a discrete tensor from the spec domain into its one-hot correspondent. Args: @@ -5023,8 +5073,8 @@ def __delitem__(self, key: NestedKey) -> None: del self._specs[key] def encode( - self, vals: Dict[str, Any], *, ignore_device: bool = False - ) -> Dict[str, torch.Tensor]: + self, vals: dict[str, Any], *, ignore_device: bool = False + ) -> dict[str, torch.Tensor]: if isinstance(vals, TensorDict): out = vals.empty() # create and empty tensordict similar to vals elif self.data_cls is not None: @@ -5059,8 +5109,8 @@ def __repr__(self) -> str: def type_check( self, - value: Union[torch.Tensor, TensorDictBase], - selected_keys: Union[str, Optional[Sequence[str]]] = None, + value: torch.Tensor | TensorDictBase, + selected_keys: str | Sequence[str] | None = None, ): if isinstance(value, torch.Tensor) and isinstance(selected_keys, str): value = {selected_keys: value} @@ -5072,7 +5122,7 @@ def type_check( ): self._specs[_key].type_check(value[_key], _key) - def is_in(self, val: Union[dict, TensorDictBase]) -> bool: + def is_in(self, val: dict | TensorDictBase) -> bool: # TODO: make warnings for these # if val.device != self.device: # print(val.device, self.device) @@ -5232,7 +5282,7 @@ def _unflatten(self, dim, sizes): def __len__(self): return len(self.keys()) - def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> Composite: + def to(self, dest: torch.dtype | DEVICE_TYPING) -> Composite: if dest is None: return self if not isinstance(dest, (str, int, torch.device)): @@ -5367,7 +5417,7 @@ def __eq__(self, other): and other.data_cls == self.data_cls ) - def update(self, dict_or_spec: Union[Composite, Dict[str, TensorSpec]]) -> None: + 
def update(self, dict_or_spec: Composite | dict[str, TensorSpec]) -> None: for key, item in dict_or_spec.items(): if key in self.keys(True) and isinstance(self[key], Composite): self[key].update(item) @@ -5836,8 +5886,8 @@ def project(self, val: TensorDictBase) -> TensorDictBase: def type_check( self, - value: Union[torch.Tensor, TensorDictBase], - selected_keys: Union[NestedKey, Optional[Sequence[NestedKey]]] = None, + value: torch.Tensor | TensorDictBase, + selected_keys: NestedKey | Sequence[NestedKey] | None = None, ): if selected_keys is None: if isinstance(value, torch.Tensor): @@ -6010,8 +6060,8 @@ def empty(self): ) def encode( - self, vals: Dict[str, Any], ignore_device: bool = False - ) -> Dict[str, torch.Tensor]: + self, vals: dict[str, Any], ignore_device: bool = False + ) -> dict[str, torch.Tensor]: raise NOT_IMPLEMENTED_ERROR def zero(self, shape: torch.Size = None) -> TensorDictBase: @@ -6482,9 +6532,9 @@ class UnboundedDiscreteTensorSpec( def __init__( self, - shape: Union[torch.Size, int] = _DEFAULT_SHAPE, - device: Optional[DEVICE_TYPING] = None, - dtype: Optional[Union[str, torch.dtype]] = torch.int64, + shape: torch.Size | int = _DEFAULT_SHAPE, + device: DEVICE_TYPING | None = None, + dtype: str | torch.dtype | None = torch.int64, **kwargs, ): super().__init__(shape=shape, device=device, dtype=dtype, **kwargs)