Skip to content

Commit 19cc931

Browse files
committed
Update
[ghstack-poisoned]
1 parent bef7a20 commit 19cc931

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

torchrl/modules/llm/transformers_policy.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
6-
# TODO: lazy imports
5+
from __future__ import annotations
76

87
import torch
98

10-
import transformers
119
from tensordict import NestedKey, TensorDictBase
1210
from tensordict.nn import (
1311
TensorDictModule as Mod,
@@ -17,7 +15,6 @@
1715
)
1816
from tensordict.tensorclass import NonTensorData, NonTensorStack
1917
from torchrl.data.llm import LLMData
20-
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
2118

2219

2320
def _maybe_clear_device(td):
@@ -107,11 +104,12 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
107104

108105

109106
def from_hf_transformers(
110-
model: transformers.modeling_utils.PreTrainedModel,
107+
model: transformers.modeling_utils.PreTrainedModel, # noqa
111108
*,
112109
generate: bool = True,
113110
return_log_probs: bool = True,
114-
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
111+
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer
112+
| None = None, # noqa
115113
from_text: bool = False,
116114
device: torch.device | None = None,
117115
kwargs: dict | None = None,
@@ -404,6 +402,9 @@ def remove_input_seq(tokens_in, tokens_out):
404402

405403

406404
if __name__ == "__main__":
405+
import transformers
406+
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
407+
407408
max_seq_length = 50000
408409

409410
tokenizer = AutoTokenizer.from_pretrained("gpt2")

torchrl/modules/llm/vllm_policy.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
57
import collections
8+
import importlib.util
69

710
import torch
8-
import transformers
9-
import vllm
1011
from tensordict import (
1112
from_dataclass,
1213
maybe_dense_stack,
@@ -22,9 +23,15 @@
2223
)
2324

2425
from torchrl.data import LLMData
25-
from vllm import LLM, SamplingParams
2626

27-
CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput)
27+
_has_vllm = importlib.util.find_spec("vllm")
28+
29+
if _has_vllm:
30+
import vllm
31+
32+
CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput)
33+
else:
34+
CompletionOutput_tc = None
2835

2936

3037
def _maybe_clear_device(td):
@@ -43,10 +50,11 @@ def _maybe_set_device(td):
4350

4451

4552
def from_vllm(
46-
model: LLM,
53+
model: vllm.LLM, # noqa
4754
*,
4855
return_log_probs: bool = False,
49-
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
56+
tokenizer: transformers.tokenization_utils.PreTrainedTokenizer # noqa
57+
| None = None, # noqa
5058
from_text: bool = False,
5159
device: torch.device | None = None,
5260
generate: bool = True,
@@ -386,6 +394,8 @@ def from_request_output(cls, requests):
386394

387395

388396
if __name__ == "__main__":
397+
from vllm import LLM, SamplingParams
398+
389399
prompts = [
390400
"Hello, my name is",
391401
"The president of the United States is",

0 commit comments

Comments
 (0)