
Commit b204dc8

[Feature] vllm wrapper

ghstack-source-id: 5e8c197
Pull Request resolved: #2830

1 parent 919af4a

File tree

1 file changed: +129 -0 lines changed

torchrl/modules/llm/vllm.py

@@ -0,0 +1,129 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

import transformers
from tensordict import NestedKey, NonTensorData, NonTensorStack, TensorDict
from tensordict.nn import (
    TensorDictModule as Mod,
    TensorDictModuleBase,
    TensorDictSequential as Seq,
)
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

def _maybe_clear_device(td):
    # Stash the source device under a non-tensor entry, then strip the device
    # so downstream vLLM calls see a device-less tensordict.
    if td.device is None:
        return td
    return td.set("_source_device", NonTensorData(td.device)).clear_device_()

def _maybe_set_device(td):
    # Move the tensordict back to the device stashed by _maybe_clear_device.
    device = td.pop("_source_device", None)
    if device is None:
        return td
    elif isinstance(device, NonTensorData):
        device: torch.device = device.data
    return td.to(device)

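# A minimal sketch of the round trip performed by the two helpers above
# (hypothetical values):
#     td = TensorDict(device="cuda:0")
#     td = _maybe_clear_device(td)  # device cleared, stashed as "_source_device"
#     td = _maybe_set_device(td)    # device restored to cuda:0
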
def from_vllm(
    model: LLM,
    return_log_probs: bool = False,
    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
    from_text: bool = False,
    device: torch.device | None = None,
    text_key: NestedKey = "text",
    generate_kwargs: dict | None = None,
    tokenizer_kwargs: dict | None = None,
) -> TensorDictModuleBase:
    # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks
    module_dict = {}
    if device:
        module_dict["clear_device"] = _maybe_clear_device
    if from_text:
        if not tokenizer_kwargs:
            tokenizer_kwargs = {}
        # Enforce the tokenizer settings the pipeline relies on: attention
        # masks, padded "pt" tensors and left padding.
        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
            raise RuntimeError
        if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
            raise RuntimeError
        if tokenizer_kwargs.setdefault("padding", True) not in (True,):
            raise RuntimeError
        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
            raise RuntimeError
        module_dict["encode"] = Mod(
            tokenizer,
            in_keys=[text_key],
            out_keys=["tokens_in"],  # method_kwargs=tokenizer_kwargs,
            strict=True,
        )

    # FIXME: this is not great!
    def f(td):
        # vLLM consumes prompt token ids as plain Python lists, so unbind the
        # batched input_ids tensor into a NonTensorStack of lists.
        td["tokens_in", "input_ids"] = NonTensorStack(
            *td["tokens_in", "input_ids"].tolist()
        )
        print("td['tokens_in', 'input_ids']", td["tokens_in", "input_ids"])
        return td

    module_dict["to_list"] = f

    if generate_kwargs is None:
        generate_kwargs = {
            "detokenize": False,
            "prompt_logprobs": return_log_probs,
            "logprobs": return_log_probs,
        }
    sampling_params = SamplingParams(**generate_kwargs)

    module_dict["generate"] = Mod(
        model,
        method="generate",
        method_kwargs={"sampling_params": sampling_params},
        in_keys={
            "prompt_token_ids": ("tokens_in", "input_ids"),
            # "attention_mask": ("tokens_in", "attention_mask"),
        },
        out_keys=["tokens_out"],
        out_to_in_map=True,
        strict=True,
    )
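    # Note: in_keys is a dict mapping the generate() keyword (prompt_token_ids)
    # to the tensordict entry it is read from; out_to_in_map=True selects this
    # orientation of the mapping.
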
    def get_output_tokens_and_log_probs(td):
        # FIXME: shouldn't have to be doing 0 index here to make sure this works with batches
        td["output_tokens"] = td["tokens_out"][0].outputs[0].token_ids
        # FIXME: this is not in a tensor form yet but uses their own LogProb object
        td["log_probs"] = td["tokens_out"][0].outputs[0].logprobs
        return td

    module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs
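    # For reference: LLM.generate returns a list of vllm RequestOutput objects,
    # one per prompt; each RequestOutput.outputs is a list of CompletionOutput
    # whose .token_ids holds the generated ids and whose .logprobs (populated
    # when logprobs are requested in SamplingParams) is a per-token list of
    # {token_id: Logprob} dicts; hence the FIXMEs above.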

    # module_dict["extract_log_probs"] = WrapModule(log_probs_from_logits, in_keys=[("tokens_in", "sequences"), ("tokens_in", "scores")], out_keys=["logits", "log_probs"])
    if from_text:
        module_dict["decode"] = Mod(
            tokenizer.batch_decode,
            in_keys=["output_tokens"],  # in_keys=["tokens_out", "sequences"],
            out_keys=["action"],  # strict=True,
        )
    if device:
        module_dict["to_source_device"] = _maybe_set_device

    return Seq(module_dict)

if __name__ == "__main__":
    max_seq_length = 50000
    model_name = "Qwen/Qwen2.5-7B-Instruct"
    model = LLM(model_name, skip_tokenizer_init=True, device="cuda:0")
    # Ask vLLM's sampler to keep the probability tensor on the GPU so that
    # log-probs can be recovered.
    model.llm_engine.model_executor.driver_worker.worker.model_runner.model.sampler.include_gpu_probs_tensor = (
        True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, device="cuda:0")
    # tokenizer.padding_side = "left"
    m = from_vllm(model, tokenizer=tokenizer, from_text=True, device="cuda:0")
    print(m(TensorDict(text="a text is a text")))
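
For illustration, a short usage sketch (not part of the commit; the key names come from the wrapper above, while the prompt and the return_log_probs setting are assumptions):

# Hypothetical usage of the wrapper with log-probs enabled.
model = LLM("Qwen/Qwen2.5-7B-Instruct", skip_tokenizer_init=True, device="cuda:0")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
wrapper = from_vllm(
    model,
    tokenizer=tokenizer,
    from_text=True,
    device="cuda:0",
    return_log_probs=True,  # fills "log_probs" alongside "output_tokens"
)
out = wrapper(TensorDict(text="The capital of France is"))
print(out["action"])     # decoded completion text
print(out["log_probs"])  # per-token log-probs (vLLM Logprob objects)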
