|
6 | 6 |
|
7 | 7 | from collections import deque |
8 | 8 | from collections.abc import Mapping |
9 | | -from copy import copy, deepcopy |
| 9 | +from copy import copy |
10 | 10 | from typing import Any, Callable, Iterable, Literal |
11 | 11 |
|
12 | 12 | import torch |
13 | 13 | from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key |
14 | | -from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams |
| 14 | +from tensordict.nn import ( |
| 15 | + ProbabilisticTensorDictModule, |
| 16 | + ProbabilisticTensorDictSequential, |
| 17 | +) |
15 | 18 | from tensordict.utils import _zip_strict, is_seq_of_nested_key |
16 | | -from torch import nn |
17 | 19 |
|
18 | 20 | from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded |
19 | 21 | from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform |
20 | | -from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param |
| 22 | +from torchrl.envs.transforms.utils import _set_missing_tolerance |
21 | 23 | from torchrl.envs.utils import make_composite_from_td |
22 | 24 |
|
23 | 25 |
|
@@ -500,6 +502,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): |
500 | 502 | return self._queue.popleft() |
501 | 503 | return out |
502 | 504 |
|
| 505 | + def __repr__(self) -> str: |
| 506 | + class_name = self.__class__.__name__ |
| 507 | + return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})" |
| 508 | + |
503 | 509 |
|
504 | 510 | class KLRewardTransform(Transform): |
505 | 511 | """A transform to add a KL[pi_current||pi_0] correction term to the reward. |
|
0 commit comments