diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 54e01b00718..db5a70bd1d9 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -19,6 +19,7 @@ import torch from tensordict import ( is_tensor_collection, + lazy_stack, LazyStackedTensorDict, TensorDict, TensorDictBase, @@ -430,7 +431,7 @@ def get(self, index: int | Sequence[int] | slice) -> Any: stack_dim = self.stack_dim if stack_dim < 0: stack_dim = out[0].ndim + 1 + stack_dim - out = LazyStackedTensorDict(*out, stack_dim=stack_dim) + out = lazy_stack(list(out), stack_dim=stack_dim) return out return out