Skip to content

Commit 3da2750

Browse files
committed
[BugFix] NonTensor should not convert anything to numpy
ghstack-source-id: 7644f6c Pull Request resolved: #2771
1 parent 09e93c1 commit 3da2750

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

torchrl/data/tensor_specs.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -2619,6 +2619,19 @@ def unbind(self, dim: int = 0):
26192619
for i in range(self.shape[dim])
26202620
)
26212621

2622+
def to_numpy(
2623+
self, val: torch.Tensor | TensorDictBase, safe: bool = None
2624+
) -> np.ndarray | dict:
2625+
return val
2626+
2627+
def encode(
2628+
self,
2629+
val: np.ndarray | torch.Tensor | TensorDictBase,
2630+
*,
2631+
ignore_device: bool = False,
2632+
) -> torch.Tensor | TensorDictBase:
2633+
return val
2634+
26222635

26232636
class _UnboundedMeta(abc.ABCMeta):
26242637
def __call__(cls, *args, **kwargs):
@@ -4918,7 +4931,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
49184931
_dict[key] = item.rand(shape)
49194932
# No need to run checks since we know Composite is compliant with
49204933
# TensorDict requirements
4921-
return TensorDict._new_unsafe(
4934+
return TensorDict(
49224935
_dict,
49234936
batch_size=_size([*shape, *self.shape]),
49244937
device=self._device,

torchrl/envs/gym_like.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
342342
for key, val in TensorDict(obs_dict, []).items(True, True)
343343
)
344344
else:
345-
tensordict_out = TensorDict._new_unsafe(
345+
tensordict_out = TensorDict(
346346
obs_dict,
347347
batch_size=tensordict.batch_size,
348348
)
@@ -376,7 +376,8 @@ def _reset(
376376

377377
source = self.read_obs(obs)
378378

379-
tensordict_out = TensorDict._new_unsafe(
379+
# _new_unsafe cannot be used because it won't wrap non-tensor correctly
380+
tensordict_out = TensorDict(
380381
source=source,
381382
batch_size=self.batch_size,
382383
)

0 commit comments

Comments
 (0)