
Commit 9f635ea

[Fix] Address remaining issues of supporting MiniCPMV (sgl-project#2977)
1 parent 76285fd commit 9f635ea

12 files changed (+708, -223 lines)

docs/references/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm
 To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically,
 - Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
 - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
+- Replace Multi-headed `Attention` of ViT with SGLang's `VisionAttention`.
 - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
 - Remove `Sample`.
 - Change `forward()` functions, and add `forward_batch`.
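
For illustration, a minimal sketch of the attention swap this checklist describes, assuming a simplified wrapper module. The names `PortedAttention`, `hidden_size`, and the toy sizing are placeholders, and the exact `RadixAttention` constructor signature may vary across sglang versions.

```python
import torch.nn as nn

from sglang.srt.layers.radix_attention import RadixAttention


class PortedAttention(nn.Module):
    """Hypothetical skeleton for illustration only; not part of this commit."""

    def __init__(self, hidden_size: int, num_heads: int, layer_id: int):
        super().__init__()
        head_dim = hidden_size // num_heads
        # vLLM would build vllm.attention.Attention(...) here; in SGLang,
        # layer_id has to be threaded all the way down to RadixAttention.
        self.attn = RadixAttention(
            num_heads,
            head_dim,
            scaling=head_dim**-0.5,
            num_kv_heads=num_heads,
            layer_id=layer_id,
        )

    def forward(self, q, k, v, forward_batch):
        # vLLM's Attention.forward takes kv caches / attention metadata;
        # SGLang's RadixAttention takes the ForwardBatch instead.
        return self.attn(q, k, v, forward_batch)
```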

python/sglang/srt/layers/attention/triton_ops/prefill_attention.py

Lines changed: 6 additions & 0 deletions
@@ -166,6 +166,12 @@ def _fwd_kernel(
 def context_attention_fwd(
     q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
 ):
+    """
+    q, k, v: [b * s, head, head_dim]
+    b_start_loc: [b]
+    b_seq_len: [b]
+    out: [b * s, head, head_dim]
+    """
     if is_cuda_available and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
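
A hedged usage sketch of `context_attention_fwd` with the shapes documented above; the toy sizes and the CUDA device are assumptions, not part of the commit.

```python
import torch

from sglang.srt.layers.attention.triton_ops.prefill_attention import (
    context_attention_fwd,
)

# Two variable-length sequences packed along the token dimension (toy sizes).
b_seq_len = torch.tensor([3, 5], dtype=torch.int32, device="cuda")    # b_seq_len: [b]
b_start_loc = torch.tensor([0, 3], dtype=torch.int32, device="cuda")  # b_start_loc: [b]
total_tokens, num_heads, head_dim = int(b_seq_len.sum()), 16, 64

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)  # output buffer: [b * s, head, head_dim]

# Non-causal, matching how the vision attention path calls it for ViT patches.
context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, int(b_seq_len.max()), is_causal=False)
```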

python/sglang/srt/layers/attention/vision.py

Lines changed: 243 additions & 40 deletions
@@ -4,6 +4,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, repeat
 
 from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
 
 
 class VisionAttention(nn.Module):
-    """Multi-headed attention without any cache, mostly used for ViT."""
+    r"""
+        Multi-headed attention without any cache, mostly used for ViT.
+
+
+    Args:
+        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+        use_context_forward (bool, default to True):
+            if ``True``, a flash_attn style attention will be applied
+            Otherwise, a full-sequence attention will be applied.
+        use_full_precision_softmax (bool, default to False):
+            if ``True``, the softmax will be performed in full-precision
+            Otherwise, it will be performed in half-precision
+
+    """
 
     def __init__(
         self,
@@ -72,25 +86,39 @@ def __init__(
         projection_size: int,
         use_qkv_parallel: bool,
         quant_config: Optional[QuantizationConfig] = None,
+        dropout: float = 0.0,
+        use_context_forward: bool = True,
+        use_full_precision_softmax: bool = False,
+        flatten_batch: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_context_forward = use_context_forward
         world_size = parallel_state.get_tensor_model_parallel_world_size()
-
+        self.dropout = dropout
+        self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, world_size
         )
-        # self.tp_size = get_tensor_model_parallel_world_size()
-        # num_heads = self.num_heads_per_partition
+
+        if self.use_context_forward:
+            self.qkv_backend = VisionTritonAttention()
+        else:
+            self.qkv_backend = VisionSdpaAttention(
+                head_size=self.head_size,
+                dropout=dropout,
+                flatten_batch=flatten_batch,
+                use_full_precision_softmax=use_full_precision_softmax,
+            )
+
         self.use_qkv_parallel = use_qkv_parallel
         if use_qkv_parallel:
-            self.head_dim = embed_dim // num_heads
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
-                head_size=self.head_dim,
+                head_size=self.head_size,
                 total_num_heads=num_heads,
                 quant_config=quant_config,
                 prefix=f"{prefix}.qkv_proj",
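
For illustration, a hedged sketch of how a ViT block might instantiate the layer with the new constructor arguments. The embedding sizes are made up, and a tensor-parallel environment is assumed to be initialized already; only the parameter names come from the diff above.

```python
from sglang.srt.layers.attention.vision import VisionAttention

# Hypothetical ViT block wiring; the sizes (1152, 16) are illustrative only.
self_attn = VisionAttention(
    embed_dim=1152,
    num_heads=16,
    projection_size=1152,
    use_qkv_parallel=True,
    use_context_forward=True,   # Triton context attention over packed patches
    # use_context_forward=False would select VisionSdpaAttention instead,
    # optionally with use_full_precision_softmax=True and flatten_batch=True.
    dropout=0.0,
)
```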
@@ -114,12 +142,15 @@ def forward(
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
         rotary_pos_emb: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        r"""
+        Args:
+            x: [b, s, embed_dim]
+            cu_seqlens: [b]
+        Returns:
+            [s, b, num_heads * head]
         """
-        Input shape: [b, s, embed_dim]
-        Output shape: [s, b, num_heads * head_size]
-        """
-
         bsz, s, _ = x.shape
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ def forward(
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
                 self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
 
-            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
 
-            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
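
The changes in this hunk and the next are pure layout changes; below is a standalone sketch of the same einops reshapes with toy sizes (all names here are illustrative, not part of the commit).

```python
import torch
from einops import rearrange

s, b, h, d = 4, 2, 3, 8  # seq, batch, heads, head_size (toy values)

qkv = torch.randn(s, b, h, 3 * d)  # [s, b, head, 3 * head_size]
q, k, v = qkv.chunk(3, dim=-1)     # 3 x [s, b, head, head_size]

# [s, b, head, head_size] --> [b, s, head, head_size]
q = rearrange(q, "s b ... -> b s ...").contiguous()
# [b, s, head, head_size] --> [b * s, head, head_size]
q = rearrange(q, "b s ... -> (b s) ...")
# [b * s, head, head_size] --> [s, b, head * head_size] (the shape fed to proj)
ctx = rearrange(q, "(b s) h d -> s b (h d)", b=b, s=s)
assert ctx.shape == (s, b, h * d)
```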
@@ -160,45 +191,217 @@ def forward(
         if self.use_qkv_parallel:
             pass
         else:
-            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            # [b, s, head, head_size] --> [b * s, head, head_size]
             q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
 
-            # [b * s, num_heads, head_size]
-            output = torch.empty_like(q)
-
-            seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
-            max_seqlen = seq_lens.max().item()
-
-            context_attention_fwd(
-                q,
-                k,
-                v,
-                output,
-                cu_seqlens.cuda(),
-                seq_lens,
-                max_seqlen,
-                is_causal=False,
-            )
+        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
 
         if self.use_qkv_parallel:
-
-            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            # [b * s, h, head_size] --> [b, s, h * head_size]
             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
 
-            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            # [b, s, h * head_size] --> [b, s, h * head_size]
             output, _ = self.proj(output)
         else:
-            # [b * s, head, head_dim] --> [b, s, head, head_dim]
-            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-            # [s, b, num_heads * head_size]
+            # [b * s, h, head_size] --> [s, b, h * head_size]
             context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
             ).contiguous()
 
-            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            # [s, b, h * head_size] --> [s, b, h * head_size]
             output, _ = self.proj(context_layer)
 
+        # [s, b, h * head_size] --> [b, s, h * head_size]
         output = output.view(bsz, s, -1)
 
         return output
+
+
+class VisionSdpaAttention(nn.Module):
+    r"""
+    Scaled Dot Product Attention inner product
+
+    """
+
+    # TODO: Should it be released after used?
+    _mask_cache = {}
+
+    def __init__(
+        self,
+        head_size: int,
+        dropout: float = 0.0,
+        flatten_batch: bool = False,
+        use_full_precision_softmax: bool = False,
+    ):
+        super().__init__()
+        self.head_size = head_size
+        self.flatten_batch = flatten_batch
+        self.use_full_precision_softmax = use_full_precision_softmax
+        self.dropout = dropout
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        bsz: int,
+        device,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+        When `flatten_batch` is True:
+            - All sequences in the batch are flattened into a single dimension
+            - `s` represents the total number of tokens across all sequences in the batch
+            - Returns a unified mask of shape `(1, 1, s, s)`
+
+        When `flatten_batch` is False:
+            - Each sequence has its own attention mask
+            - `s` represents the maximum sequence length in the batch
+            - Returns separate masks of shape `(b, 1, s, s)`
+
+        Args:
+            flatten_batch: (bool):
+                If True, treats all sequences in the batch as a single flattened sequence
+                If False, generates separate masks for each sequence
+
+        Returns:
+            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        """
+
+        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
+
+        if cache_key in VisionSdpaAttention._mask_cache:
+            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
+            # print(f"cache hit for key: {cache_key}")
+            return cached_mask.to(device=device, dtype=dtype)
+
+        if cu_seqlens is None:
+            raise ValueError("Internal Error: cu_seqlens cannot be None")
+
+        if flatten_batch:
+            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                start = cu_seqlens[i - 1]
+                end = cu_seqlens[i]
+                mask[
+                    ...,
+                    start:end,
+                    start:end,
+                ] = True
+        else:
+            # [1, 1, 1, s]
+            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            # [1, 1, s, 1]
+            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            # [b, 1, 1, 1]
+            seq_lens = (
+                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
+            )
+
+            mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+
+        # Convert to attention mask format (False -> 0, True -> -inf)
+        mask = (~mask).to(dtype) * torch.finfo(dtype).min
+
+        VisionSdpaAttention._mask_cache[cache_key] = mask
+
+        return mask
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz: int,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        s = q.shape[0] // bsz
+
+        # [b, 1, s, s]
+        if attention_mask is None:
+            attention_mask = self.generate_patch_attention_mask(
+                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+            )
+        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
+        # [b, 1, s]
+        if self.use_full_precision_softmax:
+            scale = self.head_size**-0.5
+            k_transposed = rearrange(k, "b h s d -> b h d s")
+            attn_weights = torch.matmul(q, k_transposed) * scale
+            del k, k_transposed
+            attn_weights = attn_weights + attention_mask
+            del attention_mask
+            # full-precision
+            attn_weights = nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_weights = nn.functional.dropout(
+                attn_weights, p=self.dropout, training=False
+            )
+            output = torch.matmul(attn_weights, v)
+            del attn_weights, v
+        else:
+            # SDPA
+            # [b, h, s, head_size]
+            output = F.scaled_dot_product_attention(
+                q, k, v, attention_mask, dropout_p=self.dropout
+            )
+
+        # [b, h, s, head_size] --> [b * s, h, head_size]
+        output = rearrange(output, "b h s d -> (b s) h d")
+
+        return output
+
+
+class VisionTritonAttention(nn.Module):
+    """
+    Triton-implemented attention without a causal mask
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        _bsz: int,
+        cu_seqlens: Optional[torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        # [b * s, head, head_size]
+        output = torch.empty_like(q)
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens.cuda(),
+            max_seqlen,
+            is_causal=False,
+        )
+
+        return output
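
For reference, a hedged standalone sketch of the SDPA path added above: `VisionSdpaAttention` derives the `[b, 1, s, s]` patch mask from `cu_seqlens` and runs scaled dot-product attention over packed q/k/v. The toy sizes are assumptions; padded positions of the shorter sequence are fully masked, so their outputs are not meaningful.

```python
import torch

from sglang.srt.layers.attention.vision import VisionSdpaAttention

bsz, s, heads, head_size = 2, 5, 4, 32  # toy sizes; s = max patches per image
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)  # two images: 3 and 5 patches

attn = VisionSdpaAttention(head_size=head_size, flatten_batch=False)

# Packed (and padded to s) projections: [b * s, h, head_size]
q = torch.randn(bsz * s, heads, head_size)
k = torch.randn_like(q)
v = torch.randn_like(q)

# The [b, 1, s, s] mask is built from cu_seqlens and cached per (s, bsz, ...) key.
out = attn.forward(q, k, v, bsz, cu_seqlens=cu_seqlens)
print(out.shape)  # torch.Size([10, 4, 32]) == [b * s, h, head_size]
```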

0 commit comments
