Skip to content

Commit 2b9f8f9

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent d5fb295 commit 2b9f8f9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

test/test_cost.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141
from tensordict.nn.distributions.composite import _add_suffix
4242
from tensordict.nn.utils import Buffer
43-
from tensordict.utils import unravel_key
43+
from tensordict.utils import set_capture_non_tensor_stack, unravel_key
4444
from torch import autograd, nn
4545
from torchrl._utils import _standardize
4646
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
@@ -16664,6 +16664,7 @@ def forward(self, td, mode):
1666416664

1666516665

1666616666
class TestPPO4LLMs:
16667+
@set_capture_non_tensor_stack(False)
1666716668
@pytest.mark.parametrize("from_text", [True, False])
1666816669
def test_hf(self, from_text):
1666916670
from torchrl.envs import LLMEnv, Transform

0 commit comments

Comments
 (0)