Commit e80732e

[BugFix] Fix lazy-stack in RBs
ghstack-source-id: 38399ee
Pull Request resolved: #2880
1 parent: 48111e9

File tree

2 files changed (+3, -3)

torchrl/data/replay_buffers/storages.py (+1, -1)

@@ -431,7 +431,7 @@ def get(self, index: int | Sequence[int] | slice) -> Any:
             stack_dim = self.stack_dim
             if stack_dim < 0:
                 stack_dim = out[0].ndim + 1 + stack_dim
-            out = lazy_stack(list(out), stack_dim=stack_dim)
+            out = lazy_stack(list(out), stack_dim)
             return out
         return out
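For context, a minimal sketch of the call the patched `get` now makes: lazily stacking a list of tensordicts along a possibly negative stack dim, with the same dim normalization as above. The sample data is hypothetical; only the `lazy_stack` helper from the diff is assumed, and the dim is passed positionally as in the fix.

    import torch
    from tensordict import TensorDict, lazy_stack

    # Hypothetical stand-in for the list the storage returns before stacking.
    out = [TensorDict({"obs": torch.randn(4)}, batch_size=[]) for _ in range(3)]

    stack_dim = -1
    if stack_dim < 0:
        # Same normalization as in `get`: the stacked result gains one dim.
        stack_dim = out[0].ndim + 1 + stack_dim

    stacked = lazy_stack(list(out), stack_dim)
    print(stacked.batch_size)  # torch.Size([3])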

torchrl/data/replay_buffers/utils.py (+2, -2)

@@ -17,7 +17,7 @@
 import numpy as np
 import torch
 from tensordict import (
-    LazyStackedTensorDict,
+    lazy_stack,
     MemoryMappedTensor,
     NonTensorData,
     TensorDict,
@@ -482,7 +482,7 @@ def __call__(self, data: TensorDictBase, out: TensorDictBase = None):
                 out._get_sub_tensordict((slice(None),) * i + (j,))
                 for j in range(out.shape[i])
             ]
-            out = LazyStackedTensorDict(*out_list, stack_dim=i)
+            out = lazy_stack(out_list, i)

             # Create a function that reads slices of the input data
             with out.flatten(1, -1) if out.ndim > 2 else contextlib.nullcontext(
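Both sides of this change build a lazy stack: `LazyStackedTensorDict` is the class constructor, while `lazy_stack` is the functional helper that typically returns one. A minimal sketch of the equivalence, with hypothetical sample data:

    import torch
    from tensordict import LazyStackedTensorDict, TensorDict, lazy_stack

    # Hypothetical stand-in for out_list in __call__ above.
    out_list = [TensorDict({"a": torch.zeros(2)}, batch_size=[2]) for _ in range(3)]

    via_class = LazyStackedTensorDict(*out_list, stack_dim=0)
    via_helper = lazy_stack(out_list, 0)

    # Both produce a lazy stack along dim 0 with batch size [3, 2].
    assert (via_class == via_helper).all()
    print(via_helper.batch_size)  # torch.Size([3, 2])

Note that the helper takes the list itself and the dim positionally, which is the form the patch switches to.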
