1919
2020from torchrl ._utils import _ContextManager , _DecoratorContextManager
2121from torchrl .data .tensor_specs import Unbounded
22- from torchrl .objectives .value .functional import (
23- _inv_pad_sequence ,
24- _split_and_pad_sequence ,
25- )
26- from torchrl .objectives .value .utils import _get_num_per_traj_init
2722
2823
2924class LSTMCell (RNNCellBase ):
@@ -718,6 +713,11 @@ def set_recurrent_mode(self, mode: bool = True):
718713
719714 @dispatch
720715 def forward (self , tensordict : TensorDictBase ):
716+ from torchrl .objectives .value .functional import (
717+ _inv_pad_sequence ,
718+ _split_and_pad_sequence ,
719+ )
720+
721721 # we want to get an error if the value input is missing, but not the hidden states
722722 defaults = [NO_DEFAULT , None , None ]
723723 shape = tensordict .shape
@@ -742,6 +742,8 @@ def forward(self, tensordict: TensorDictBase):
742742 is_init = tensordict_shaped ["is_init" ].squeeze (- 1 )
743743 splits = None
744744 if self .recurrent_mode and is_init [..., 1 :].any ():
745+ from torchrl .objectives .value .utils import _get_num_per_traj_init
746+
745747 # if we have consecutive trajectories, things get a little more complicated
746748 # we have a tensordict of shape [B, T]
747749 # we will split / pad things such that we get a tensordict of shape
@@ -1533,6 +1535,11 @@ def set_recurrent_mode(self, mode: bool = True):
15331535 @dispatch
15341536 @set_lazy_legacy (False )
15351537 def forward (self , tensordict : TensorDictBase ):
1538+ from torchrl .objectives .value .functional import (
1539+ _inv_pad_sequence ,
1540+ _split_and_pad_sequence ,
1541+ )
1542+
15361543 # we want to get an error if the value input is missing, but not the hidden states
15371544 defaults = [NO_DEFAULT , None ]
15381545 shape = tensordict .shape
@@ -1557,6 +1564,8 @@ def forward(self, tensordict: TensorDictBase):
15571564 is_init = tensordict_shaped ["is_init" ].squeeze (- 1 )
15581565 splits = None
15591566 if self .recurrent_mode and is_init [..., 1 :].any ():
1567+ from torchrl .objectives .value .utils import _get_num_per_traj_init
1568+
15601569 # if we have consecutive trajectories, things get a little more complicated
15611570 # we have a tensordict of shape [B, T]
15621571 # we will split / pad things such that we get a tensordict of shape
0 commit comments