
Commit 84ab88d

Support flashinfer for Gemma3 prefill (#3167)
* launcher: ensure correct detection of Gemma 3 head size

* Support flashinfer for Gemma3 prefill

  Gemma3 uses bidirectional attention for images. Flashinfer supports
  custom masks. Hook up the mask with flashinfer, so that we do not have
  to use the slower SDPA implementation for prefills with images.

* Update Gemma3 test outputs

* Fixed unused import
1 parent 4645678 commit 84ab88d
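
For readers skimming the diffs below, the core idea: Gemma3's image tokens attend bidirectionally among themselves during prefill, while text tokens keep the usual causal pattern, and flashinfer can enforce that through a custom mask. A simplified sketch of such a mask (illustrative only, not code from this commit; it also treats all image tokens as one group, whereas the model restricts bidirectionality to tokens of the same image):

import torch

def gemma3_style_prefill_mask(image_token_mask: torch.Tensor) -> torch.Tensor:
    # image_token_mask: (seq_len,) bool, True at image placeholder tokens.
    # Returns (seq_len, seq_len) bool, True where attention is allowed.
    seq_len = image_token_mask.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Image tokens may additionally look "forward" at other image tokens.
    bidirectional = image_token_mask[:, None] & image_token_mask[None, :]
    return causal | bidirectional

When ATTENTION == "flashinfer", a boolean mask of this shape can be handed to the prefill wrapper directly instead of falling back to the slower SDPA path.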

File tree

10 files changed: +141 -109 lines

integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3.json

+82 -82 (large diff, not fully rendered)

@@ -1,26 +1,26 @@
 {
   "choices": [
     {
-      "finish_reason": "stop",
+      "finish_reason": "length",
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail or perhaps speculate about the context of the image?",
+        "content": "Okay, let's analyze the image. \n\nThe image is entirely white, with a very subtle, faint outline of a stylized, cartoonish figure. It appears to be a simplified depiction of a person, likely a child, with a wide-eyed expression and a small, rounded body. \n\nIt's almost like a minimalist, iconic representation. \n\nDo you want me to try and describe it in more detail, or perhaps suggest what this image might represent (e.g",
         "name": null,
         "role": "assistant",
         "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1741965892,
+  "created": 1744396706,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
-    "completion_tokens": 98,
+    "completion_tokens": 100,
     "prompt_tokens": 277,
-    "total_tokens": 375
+    "total_tokens": 377
   }
 }

integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_base64_rgba.json

+5 -5

@@ -5,22 +5,22 @@
       "index": 0,
       "logprobs": null,
       "message": {
-        "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nDo you want me to describe any specific element of the image in more detail?",
+        "content": "Okay, let's analyze the image. \n\nThe transparent image reveals a stylized depiction of **a human head**. It's a minimalist, geometric representation, showing the basic shapes of the skull, eye sockets, and head outline. \n\nIf you'd like, you can give me more details about the image or ask me to focus on a specific aspect of it.",
         "name": null,
         "role": "assistant",
         "tool_calls": null
       },
       "usage": null
     }
   ],
-  "created": 1741966313,
+  "created": 1744396703,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
-    "completion_tokens": 67,
+    "completion_tokens": 78,
     "prompt_tokens": 277,
-    "total_tokens": 344
+    "total_tokens": 355
   }
 }

integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow.json

+2 -2

@@ -13,11 +13,11 @@
       "usage": null
     }
   ],
-  "created": 1741964480,
+  "created": 1744396699,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
     "completion_tokens": 74,
     "prompt_tokens": 275,

integration-tests/models/__snapshots__/test_flash_gemma3/test_flash_gemma3_image_cow_dog.json

+2 -2

@@ -13,11 +13,11 @@
       "usage": null
     }
   ],
-  "created": 1741964477,
+  "created": 1744396697,
   "id": "",
   "model": "google/gemma-3-4b-it",
   "object": "chat.completion",
-  "system_fingerprint": "3.2.1-dev0-native",
+  "system_fingerprint": "3.2.3-dev0-native",
   "usage": {
     "completion_tokens": 75,
     "prompt_tokens": 279,

launcher/src/main.rs

+16 -5

@@ -260,11 +260,22 @@ struct Config {

 impl Config {
     fn get_head_dim(&self) -> Option<usize> {
-        self.head_dim.or_else(|| {
-            self.text_config
-                .as_ref()
-                .and_then(|text_config| text_config.head_dim)
-        })
+        if let Some(head_dim) = self.head_dim {
+            return Some(head_dim);
+        }
+
+        let text_config = self.text_config.as_ref()?;
+        if let Some(head_size) = text_config.head_dim {
+            return Some(head_size);
+        }
+
+        match self.model_type.as_deref() {
+            // We special-case gemma3 here, since we need flashinfer for
+            // handling bidirectional masks. And flashinfer can only be
+            // used when the head size is known.
+            Some("gemma3") => Some(256),
+            _ => None,
+        }
     }

     fn flop(&self) -> Option<u64> {

server/text_generation_server/layers/attention/flashinfer.py

+2

@@ -45,6 +45,7 @@ def use_prefill_with_paged_kv_state(
     state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
     block_tables: torch.Tensor,
     cu_seqlens: torch.Tensor,
+    custom_mask: Optional[torch.Tensor],
     input_lengths: torch.Tensor,
     num_heads: int,
     num_kv_heads: int,
@@ -88,6 +89,7 @@ def use_prefill_with_paged_kv_state(
         paged_kv_indptr=indptr,
         paged_kv_indices=block_tables,
         paged_kv_last_page_len=last_page_len,
+        custom_mask=custom_mask,
         num_qo_heads=num_heads,
         num_kv_heads=num_kv_heads,
         head_dim=head_size,

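A note on the custom_mask threaded through here: as I read flashinfer's API, the prefill wrappers expect the mask in a packed 1-D layout, i.e. each request's (q_len x kv_len) boolean matrix flattened and concatenated over the batch, which is why the caller flattens with reshape(-1) further down. A hypothetical helper for that layout (pack_custom_masks is my name, not part of this codebase):

import torch

def pack_custom_masks(masks: list[torch.Tensor]) -> torch.Tensor:
    # Each entry is the (q_len, kv_len) bool mask of one request; the
    # packed layout is all matrices flattened and concatenated in order.
    return torch.cat([m.reshape(-1) for m in masks])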
server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py

+15 -7

@@ -45,6 +45,7 @@
 from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
+from text_generation_server.models.globals import ATTENTION
 from text_generation_server.utils.weights import UnquantizedWeight
 from transformers.activations import ACT2FN
 from text_generation_server.layers.attention import (
@@ -248,7 +249,7 @@ def forward(

         # Prefill
         if cu_seqlen_prefill is not None:
-            if attention_mask is None:
+            if attention_mask is None or ATTENTION == "flashinfer":
                 # flash attention
                 attn_output = attention(
                     query=query,
@@ -701,8 +702,16 @@ def __init__(self, prefix, config, weights):
         )

     def get_attention_mask(
-        self, input_ids, max_s, cu_seqlen_prefill, dtype, image_token_mask
+        self,
+        input_ids: torch.Tensor,
+        cu_seqlen_prefill: Optional[torch.Tensor],
+        dtype: torch.dtype,
+        bool_mask: bool = False,
     ):
+        image_token_mask = (input_ids == self.config.image_token_index).to(
+            input_ids.device
+        )
+
         device = input_ids.device
         min_dtype = torch.finfo(dtype).min

@@ -748,9 +757,10 @@ def get_attention_mask(
         )
         full_attention_mask[:, :, :, :sequence_length] = combined_mask

-        final_attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(device)
-
-        return final_attention_mask
+        if bool_mask:
+            return full_attention_mask
+        else:
+            return torch.where(full_attention_mask, 0, min_dtype).to(device)

     def forward(
         self,
@@ -793,10 +803,8 @@ def forward(
             )
             attention_mask = self.get_attention_mask(
                 input_ids,
-                max_s,
                 cu_seqlen_prefill,
                 inputs_embeds.dtype,
-                image_token_mask,
             )
             # Use flash attention for text-only input
             # else:

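The new bool_mask flag exists because the two attention paths encode masks differently: flashinfer consumes the boolean keep-mask itself, while the SDPA fallback needs an additive float mask. A minimal standalone sketch of the conversion the else branch performs (it mirrors the torch.where in the diff above; shown separately only for clarity):

import torch

def to_additive(bool_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True -> 0.0 (position may be attended); False -> the most negative
    # value of the dtype, so softmax sends blocked positions to ~zero.
    min_dtype = torch.finfo(dtype).min
    return torch.where(bool_mask, 0.0, min_dtype).to(dtype)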
server/text_generation_server/models/flash_causal_lm.py

+2

@@ -2434,6 +2434,7 @@ def _forward_context(
        input_lengths_tensor: torch.Tensor,
        cache_lengths_tensor: torch.Tensor,
        state: Optional[Any] = None,
+       attention_mask: Optional[torch.Tensor] = None,
    ) -> ContextManager:
        if ATTENTION != "flashinfer":
            return nullcontext()
@@ -2450,6 +2451,7 @@ def _forward_context(
            ),
            block_tables=block_tables,
            cu_seqlens=cu_seqlen_prefill,
+           custom_mask=attention_mask,
            input_lengths=input_lengths_tensor + cache_lengths_tensor,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,

server/text_generation_server/models/vlm_causal_lm.py

+9

@@ -485,6 +485,14 @@ def forward(
         )
         batch.position_ids = position_ids

+        if self.model.config.model_type == "gemma3" and cu_seqlen_prefill is not None:
+            # Get the mask, needed for flashinfer.
+            attention_mask = self.model.get_attention_mask(
+                input_ids, cu_seqlen_prefill, self.dtype, bool_mask=True
+            ).reshape(-1)
+        else:
+            attention_mask = None
+
         # Try to find an associated cuda graph
         bs = input_ids.shape[0]
         sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
@@ -508,6 +516,7 @@ def forward(
             cu_seqlen_prefill=cu_seqlen_prefill,
             input_lengths_tensor=input_lengths,
             cache_lengths_tensor=cache_lengths_tensor,
+            attention_mask=attention_mask,
         ):
             seqlen = Seqlen(
                 input_lengths=input_lengths,

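Worth noting: the mask is only built for prefill (cu_seqlen_prefill is not None). During decode each request contributes a single query token that attends causally to the entire KV cache, where the image tokens already reside, so the regular paged decode path needs no custom mask and attention_mask stays None.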