Skip to content

Commit 0130c5d

Browse files
committed
[BugFix] Ensure that Composite.set returns self as TensorDict does
ghstack-source-id: a36213f34d3af93d5cacbf10f2a60fe3a874a9a6 Pull Request resolved: #2784
1 parent 74f6075 commit 0130c5d

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

torchrl/data/tensor_specs.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -4665,7 +4665,11 @@ def separates(self, *keys: NestedKey, default: Any = None) -> Composite:
46654665
out[key] = result
46664666
return out
46674667

4668-
def set(self, name, spec):
4668+
def set(self, name: str, spec: TensorSpec) -> Composite:
4669+
"""Sets a spec in the Composite spec."""
4670+
if not isinstance(name, str):
4671+
self[name] = spec
4672+
return self
46694673
if self.locked:
46704674
raise RuntimeError("Cannot modify a locked Composite.")
46714675
if spec is not None and self.device is not None and spec.device != self.device:
@@ -4698,6 +4702,7 @@ def set(self, name, spec):
46984702
f"Composite.shape={self.shape}."
46994703
)
47004704
self._specs[name] = spec
4705+
return self
47014706

47024707
def __init__(
47034708
self, *args, shape: torch.Size = None, device: torch.device = None, **kwargs
@@ -5733,9 +5738,10 @@ def ndim(self):
57335738
def ndimension(self):
57345739
return len(self.shape)
57355740

5736-
def set(self, name, spec):
5741+
def set(self, name: str, spec: TensorSpec) -> StackedComposite:
57375742
for sub_spec, sub_item in zip(self._specs, spec.unbind(self.dim)):
57385743
sub_spec[name] = sub_item
5744+
return self
57395745

57405746
@property
57415747
def shape(self):

0 commit comments

Comments
 (0)