Skip to content

Commit bf3b79e

[VLM] Qwen2.5-VL
1 parent 9a5b155 commit bf3b79e

File tree

14 files changed: +1315 -52 lines changed


Diff for: docs/source/models/supported_models.md

+11 lines

@@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
   * ✅︎
   * ✅︎
+- * `Qwen2_5_VLForConditionalGeneration`
+  * Qwen2.5-VL
+  * T + I<sup>E+</sup> + V<sup>E+</sup>
+  * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
+  *
+  * ✅︎
+  * ✅︎
 - * `UltravoxModel`
   * Ultravox
   * T + A<sup>E+</sup>
@@ -880,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf
 A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
 :::

+:::{note}
+To use Qwen2.5-VL series models, you have to install the Hugging Face `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`.
+:::
+
 ### Pooling Models

 See [this page](pooling-models) for more information on how to use pooling models.
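For readers of the new docs entry, a minimal offline-inference sketch (not part of this commit; the image URL is a placeholder) that exercises the model through vLLM's chat API might look like this, assuming `transformers` has been installed from source as the note above requires:

    # Minimal sketch (not from this commit): load the new model and ask about
    # an image via vLLM's OpenAI-style chat messages. The URL is a placeholder.
    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", max_model_len=4096)

    messages = [{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/demo.jpg"}},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }]

    outputs = llm.chat(messages, sampling_params=SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)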

Diff for: examples/offline_inference/vision_language.py

+31 lines

@@ -531,6 +531,36 @@ def run_qwen2_vl(question: str, modality: str):
     return llm, prompt, stop_token_ids


+# Qwen2.5-VL
+def run_qwen2_5_vl(question: str, modality: str):
+
+    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+    llm = LLM(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+            "fps": 1,
+        },
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
+
+    if modality == "image":
+        placeholder = "<|image_pad|>"
+    elif modality == "video":
+        placeholder = "<|video_pad|>"
+
+    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+              f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+              f"{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
 model_example_map = {
     "aria": run_aria,
     "blip-2": run_blip2,
@@ -557,6 +587,7 @@ def run_qwen2_vl(question: str, modality: str):
     "pixtral_hf": run_pixtral_hf,
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
+    "qwen2_5_vl": run_qwen2_5_vl,
 }

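The prompt string built by `run_qwen2_5_vl` is meant to be passed to vLLM together with the raw image under `multi_modal_data`; a rough standalone sketch of that call (not part of this commit; `demo.jpg` is a hypothetical local file) is:

    # Rough sketch of how the formatted prompt above is consumed;
    # "demo.jpg" is a hypothetical local image path.
    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
              max_model_len=4096,
              mm_processor_kwargs={"min_pixels": 28 * 28,
                                   "max_pixels": 1280 * 28 * 28})

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
              "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              "Describe this image.<|im_end|>\n"
              "<|im_start|>assistant\n")

    outputs = llm.generate(
        {"prompt": prompt,
         "multi_modal_data": {"image": Image.open("demo.jpg")}},
        sampling_params=SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)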

Diff for: examples/offline_inference/vision_language_multi_image.py

+58 lines

@@ -392,6 +392,63 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     )


+def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData:
+    try:
+        from qwen_vl_utils import process_vision_info
+    except ModuleNotFoundError:
+        print('WARNING: `qwen-vl-utils` not installed, input images will not '
+              'be automatically resized. You can enable this functionality by '
+              '`pip install qwen-vl-utils`.')
+        process_vision_info = None
+
+    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+    llm = LLM(
+        model=model_name,
+        max_model_len=32768 if process_vision_info is None else 4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    stop_token_ids = None
+
+    if process_vision_info is None:
+        image_data = [fetch_image(url) for url in image_urls]
+    else:
+        image_data, _ = process_vision_info(messages,
+                                            return_video_sample_fps=False)
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=image_data,
+        chat_template=None,
+    )
+
+
 model_example_map = {
     "aria": load_aria,
     "deepseek_vl_v2": load_deepseek_vl2,
@@ -404,6 +461,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
     "pixtral_hf": load_pixtral_hf,
     "qwen_vl_chat": load_qwen_vl_chat,
     "qwen2_vl": load_qwen2_vl,
+    "qwen2_5_vl": load_qwen2_5_vl,
 }

Diff for: tests/models/decoder_only/vision_language/test_models.py

+22 lines

@@ -121,6 +121,8 @@
                     else ("half", "float")),
         marks=[pytest.mark.core_model],
     ),
+    # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
+    # once we upgrade to transformers>=4.49.0.
     "qwen2_vl": VLMTestInfo(
         models=["Qwen/Qwen2-VL-2B-Instruct"],
         test_type=(
@@ -138,6 +140,26 @@
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "qwen2_5_vl": VLMTestInfo(
+        models=["Qwen/Qwen2.5-VL-3B-Instruct"],
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+            VLMTestType.VIDEO
+        ),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
+        img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
+        video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForVision2Seq,
+        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
+        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+        marks=[pytest.mark.skipif(
+            TRANSFORMERS_VERSION < "4.49.0",
+            reason="HF model requires transformers>=4.49.0",
+        ), pytest.mark.core_model, pytest.mark.cpu_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],

Diff for: tests/models/multimodal/processing/test_common.py

+1 line

@@ -161,6 +161,7 @@ def _test_processing_correctness(
     "nvidia/NVLM-D-72B",
     "Qwen/Qwen-VL-Chat",
     "Qwen/Qwen2-VL-2B-Instruct",
+    "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
     "fixie-ai/ultravox-v0_3",
 ])

Diff for: tests/models/registry.py

+2 lines

@@ -264,6 +264,8 @@ def check_available_online(
                                           trust_remote_code=True),
     "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
+    "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
+                                                          min_transformers_version="4.49"), # noqa: E501
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
                                      trust_remote_code=True),
     # [Encoder-decoder]

Diff for: vllm/entrypoints/chat_utils.py

+2 -2 lines

@@ -410,7 +410,7 @@ def _placeholder_str(self, modality: ModalityStr,
                 return "<image>"
             if model_type == "mllama":
                 return "<|image|>"
-            if model_type == "qwen2_vl":
+            if model_type in ("qwen2_vl", "qwen2_5_vl"):
                 return "<|vision_start|><|image_pad|><|vision_end|>"
             if model_type == "molmo":
                 return ""
@@ -430,7 +430,7 @@ def _placeholder_str(self, modality: ModalityStr,
                 return "(<audio>./</audio>)"
             raise TypeError(f"Unknown model type: {model_type}")
         elif modality == "video":
-            if model_type == "qwen2_vl":
+            if model_type in ("qwen2_vl", "qwen2_5_vl"):
                 return "<|vision_start|><|video_pad|><|vision_end|>"
             if model_type in ("minicpmo", "minicpmv"):
                 return "(<video>./</video>)"

Diff for: vllm/model_executor/layers/rotary_embedding.py

+34 -24 lines

@@ -27,6 +27,7 @@

 import torch
 import torch.nn as nn
+from transformers import PretrainedConfig

 from vllm.model_executor.custom_op import CustomOp

@@ -772,8 +773,12 @@ def __init__(
         dtype: torch.dtype,
         mrope_section: Optional[List[int]] = None,
     ) -> None:
-        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
-                         is_neox_style, dtype)
+        # In Qwen2.5-VL, the maximum index value is related to the duration of
+        # the input video. We enlarge max_position_embeddings to 4 times its
+        # original value to get a larger cos and sin cache.
+        self.cache_max_position_num = max_position_embeddings * 4
+        super().__init__(head_size, rotary_dim, self.cache_max_position_num,
+                         base, is_neox_style, dtype)

         self.mrope_section = mrope_section
         if self.mrope_section:
@@ -831,49 +836,47 @@ def forward(
     @staticmethod
     def get_input_positions(
         input_tokens: List[int],
+        hf_config: PretrainedConfig,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
         video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
-        vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
+        second_per_grid_ts: Optional[List[float]] = None,
         context_len: int = 0,
         seq_len: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""

         llm_positions, mrope_position_delta = \
             MRotaryEmbedding.get_input_positions_tensor(
-                input_tokens,
-                image_grid_thw,
-                video_grid_thw,
-                image_token_id,
-                video_token_id,
-                vision_start_token_id,
-                vision_end_token_id,
-                spatial_merge_size,
-                context_len,
-                seq_len,
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                second_per_grid_ts=second_per_grid_ts,
+                context_len=context_len,
+                seq_len=seq_len,
             )

         return llm_positions.tolist(), mrope_position_delta

     @staticmethod
     def get_input_positions_tensor(
         input_tokens: List[int],
+        hf_config: PretrainedConfig,
         image_grid_thw: Union[List[List[int]], torch.Tensor],
         video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
-        vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
+        second_per_grid_ts: Optional[List[float]] = None,
         context_len: int = 0,
         seq_len: Optional[int] = None,
     ) -> Tuple[torch.Tensor, int]:
         """Get mrope input positions and delta value."""

+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(hf_config.vision_config,
+                                    "tokens_per_second", 1.0)
+
         if isinstance(image_grid_thw, torch.Tensor):
             image_grid_thw = image_grid_thw.tolist()
         if isinstance(video_grid_thw, torch.Tensor):
@@ -892,6 +895,7 @@ def get_input_positions_tensor(

         image_index, video_index = 0, 0
         for _ in range(image_nums + video_nums):
+            video_second_per_grid_t = 0.0
             if image_token_id in input_tokens and remain_images > 0:
                 ed_image = input_tokens.index(image_token_id, st)
             else:
@@ -915,9 +919,13 @@ def get_input_positions_tensor(
                     video_grid_thw[video_index][1],
                     video_grid_thw[video_index][2],
                 )
+                video_second_per_grid_t = 1.0
+                if second_per_grid_ts is not None:
+                    video_second_per_grid_t = second_per_grid_ts[video_index]
                 video_index += 1
                 remain_videos -= 1
                 ed = ed_video
+
             llm_grid_t, llm_grid_h, llm_grid_w = \
                 t, h // spatial_merge_size, w // spatial_merge_size
             text_len = ed - st
@@ -927,8 +935,10 @@ def get_input_positions_tensor(
             llm_pos_ids_list.append(
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

-            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
-                -1, llm_grid_h * llm_grid_w).flatten()
+            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
+                       tokens_per_second).long().flatten()
+
             h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
                 llm_grid_t, -1, llm_grid_w).flatten()
             w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
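The change to `t_index` scales temporal positions for video grids by real elapsed time (seconds per temporal grid step times the config's `tokens_per_second`) instead of stepping by one per grid row; a toy standalone sketch of the computation (toy numbers, not vLLM code):

    # Toy illustration of the new scaled temporal index.
    import torch

    llm_grid_t, llm_grid_h, llm_grid_w = 4, 2, 2  # toy grid sizes
    tokens_per_second = 2.0          # would come from hf_config.vision_config
    video_second_per_grid_t = 0.5    # would come from second_per_grid_ts

    t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
        -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
               tokens_per_second).long().flatten()

    # Each temporal grid step advances the position by
    # video_second_per_grid_t * tokens_per_second = 1 here:
    # tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])
    print(t_index)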
