Skip to content

Commit d90b9e3

Browse files
committed
[BugFix] Fix imports
ghstack-source-id: db85f26 Pull Request resolved: #2605
1 parent a1e21f5 commit d90b9e3

File tree

4 files changed

+32
-21
lines changed

4 files changed

+32
-21
lines changed

benchmarks/test_objectives_benchmarks.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1152,4 +1152,7 @@ def loss_and_bw(td):
11521152

11531153
if __name__ == "__main__":
11541154
args, unknown = argparse.ArgumentParser().parse_known_args()
1155-
pytest.main([__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + unknown)
1155+
pytest.main(
1156+
[__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"]
1157+
+ unknown
1158+
)

setup.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,18 @@ def _main(argv):
209209
"dm_control": ["dm_control"],
210210
"gym_continuous": ["gymnasium<1.0", "mujoco"],
211211
"rendering": ["moviepy<2.0.0"],
212-
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
212+
"tests": [
213+
"pytest",
214+
"pyyaml",
215+
"pytest-instafail",
216+
"scipy",
217+
"pytest-mock",
218+
"pytest-cov",
219+
"pytest-benchmark",
220+
"pytest-rerunfailures",
221+
"pytest-error-for-skips",
222+
"",
223+
],
213224
"utils": [
214225
"tensorboard",
215226
"wandb",

test/test_rlhf.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -298,12 +298,10 @@ def test_tensordict_tokenizer(
298298
"Lettuce in, it's cold out here!",
299299
]
300300
}
301-
if not truncation and return_tensordict and max_length == 10:
302-
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
303-
out = process(example)
304-
return
305301
out = process(example)
306-
if return_tensordict:
302+
if not truncation and return_tensordict and max_length == 10:
303+
assert out.get("input_ids").shape[-1] == -1
304+
elif return_tensordict:
307305
assert out.get("input_ids").shape[-1] == max_length
308306
else:
309307
obj = out.get("input_ids")
@@ -346,12 +344,10 @@ def test_prompt_tensordict_tokenizer(
346344
],
347345
"label": ["right", "wrong", "right", "wrong", "right"],
348346
}
349-
if not truncation and return_tensordict and max_length == 10:
350-
with pytest.raises(ValueError, match="TensorDict conversion only supports"):
351-
out = process(example)
352-
return
353347
out = process(example)
354-
if return_tensordict:
348+
if not truncation and return_tensordict and max_length == 10:
349+
assert out.get("input_ids").shape[-1] == -1
350+
elif return_tensordict:
355351
assert out.get("input_ids").shape[-1] == max_length
356352
else:
357353
obj = out.get("input_ids")

torchrl/envs/common.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from __future__ import annotations
77

88
import abc
9-
import functools
109
import warnings
1110
from copy import deepcopy
11+
from functools import partial, wraps
1212
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
1313

1414
import numpy as np
@@ -33,6 +33,7 @@
3333
_StepMDP,
3434
_terminated_or_truncated,
3535
_update_during_reset,
36+
check_env_specs as check_env_specs_func,
3637
get_available_libraries,
3738
)
3839

@@ -2035,7 +2036,7 @@ def _register_gym(
20352036

20362037
if entry_point is None:
20372038
entry_point = cls
2038-
entry_point = functools.partial(
2039+
entry_point = partial(
20392040
_TorchRLGymWrapper,
20402041
entry_point=entry_point,
20412042
info_keys=info_keys,
@@ -2084,7 +2085,7 @@ def _register_gym( # noqa: F811
20842085

20852086
if entry_point is None:
20862087
entry_point = cls
2087-
entry_point = functools.partial(
2088+
entry_point = partial(
20882089
_TorchRLGymWrapper,
20892090
entry_point=entry_point,
20902091
info_keys=info_keys,
@@ -2138,7 +2139,7 @@ def _register_gym( # noqa: F811
21382139

21392140
if entry_point is None:
21402141
entry_point = cls
2141-
entry_point = functools.partial(
2142+
entry_point = partial(
21422143
_TorchRLGymWrapper,
21432144
entry_point=entry_point,
21442145
info_keys=info_keys,
@@ -2195,7 +2196,7 @@ def _register_gym( # noqa: F811
21952196

21962197
if entry_point is None:
21972198
entry_point = cls
2198-
entry_point = functools.partial(
2199+
entry_point = partial(
21992200
_TorchRLGymWrapper,
22002201
entry_point=entry_point,
22012202
info_keys=info_keys,
@@ -2254,7 +2255,7 @@ def _register_gym( # noqa: F811
22542255
)
22552256
if entry_point is None:
22562257
entry_point = cls
2257-
entry_point = functools.partial(
2258+
entry_point = partial(
22582259
_TorchRLGymWrapper,
22592260
entry_point=entry_point,
22602261
info_keys=info_keys,
@@ -2293,7 +2294,7 @@ def _register_gym( # noqa: F811
22932294
if entry_point is None:
22942295
entry_point = cls
22952296

2296-
entry_point = functools.partial(
2297+
entry_point = partial(
22972298
_TorchRLGymnasiumWrapper,
22982299
entry_point=entry_point,
22992300
info_keys=info_keys,
@@ -3422,11 +3423,11 @@ def _get_sync_func(policy_device, env_device):
34223423
if policy_device is not None and policy_device.type == "cuda":
34233424
if env_device is None or env_device.type == "cuda":
34243425
return torch.cuda.synchronize
3425-
return functools.partial(torch.cuda.synchronize, device=policy_device)
3426+
return partial(torch.cuda.synchronize, device=policy_device)
34263427
if env_device is not None and env_device.type == "cuda":
34273428
if policy_device is None:
34283429
return torch.cuda.synchronize
3429-
return functools.partial(torch.cuda.synchronize, device=env_device)
3430+
return partial(torch.cuda.synchronize, device=env_device)
34303431
return torch.cuda.synchronize
34313432
if torch.backends.mps.is_available():
34323433
return torch.mps.synchronize

0 commit comments

Comments
 (0)