import torch
import torch.nn as nn
+ import torch.nn.functional as F
from einops import rearrange, repeat

from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T


class VisionAttention(nn.Module):
-     """Multi-headed attention without any cache, mostly used for ViT."""
+     r"""
+     Multi-headed attention without any cache, mostly used for ViT.
+
+     Args:
+         use_qkv_parallel (bool, optional): If ``True``, use a tensor-parallel QKV projection.
+         use_context_forward (bool, default: True):
+             If ``True``, a flash-attention-style Triton kernel is used.
+             Otherwise, full-sequence SDPA attention is used.
+         use_full_precision_softmax (bool, default: False):
+             If ``True``, the softmax is computed in full precision.
+             Otherwise, it is computed in half precision.
+     """

    def __init__(
        self,
@@ -72,25 +86,39 @@ def __init__(
        projection_size: int,
        use_qkv_parallel: bool,
        quant_config: Optional[QuantizationConfig] = None,
+         dropout: float = 0.0,
+         use_context_forward: bool = True,
+         use_full_precision_softmax: bool = False,
+         flatten_batch: bool = False,
        prefix: str = "",
    ):
        super().__init__()
+         self.use_context_forward = use_context_forward
        world_size = parallel_state.get_tensor_model_parallel_world_size()
-
+         self.dropout = dropout
+         self.head_size = embed_dim // num_heads
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads
        )
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, world_size
        )
-         # self.tp_size = get_tensor_model_parallel_world_size()
-         # num_heads = self.num_heads_per_partition
+
+         if self.use_context_forward:
+             self.qkv_backend = VisionTritonAttention()
+         else:
+             self.qkv_backend = VisionSdpaAttention(
+                 head_size=self.head_size,
+                 dropout=dropout,
+                 flatten_batch=flatten_batch,
+                 use_full_precision_softmax=use_full_precision_softmax,
+             )
+
        self.use_qkv_parallel = use_qkv_parallel
        if use_qkv_parallel:
-             self.head_dim = embed_dim // num_heads
            self.qkv_proj = QKVParallelLinear(
                hidden_size=embed_dim,
-                 head_size=self.head_dim,
+                 head_size=self.head_size,
                total_num_heads=num_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.qkv_proj",
@@ -114,12 +142,15 @@ def forward(
        x: torch.Tensor,
        cu_seqlens: Optional[torch.Tensor] = None,
        rotary_pos_emb: torch.Tensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
+         r"""
+         Args:
+             x: [b, s, embed_dim]
+             cu_seqlens: [b + 1]
+         Returns:
+             [b, s, num_heads * head_size]
        """
-         Input shape: [b, s, embed_dim]
-         Output shape: [s, b, num_heads * head_size]
-         """
-
        bsz, s, _ = x.shape
        if self.use_qkv_parallel:
            # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ def forward(
        else:
            # [b, s, embed_dim] --> [s, b, embed_dim]
            x = rearrange(x, "b s ... -> s b ...")
-             # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+             # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
            qkv, _ = self.qkv_proj(x)
-             # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+             # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
            new_x_shape = qkv.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)

-             # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+             # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)

-             # [s, b, head, head_dim] --> [b, s, head, head_dim]
+             # [s, b, head, head_size] --> [b, s, head, head_size]
            q, k, v = [
                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
            ]
@@ -160,45 +191,217 @@ def forward(
        if self.use_qkv_parallel:
            pass
        else:
-             # [b, s, head, head_dim] --> [b * s, head, head_dim]
+             # [b, s, head, head_size] --> [b * s, head, head_size]
            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

-         # [b * s, num_heads, head_size]
-         output = torch.empty_like(q)
-
-         seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
-         max_seqlen = seq_lens.max().item()
-
-         context_attention_fwd(
-             q,
-             k,
-             v,
-             output,
-             cu_seqlens.cuda(),
-             seq_lens,
-             max_seqlen,
-             is_causal=False,
-         )
+         output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)

        if self.use_qkv_parallel:
-
-             # [b * s, head, head_dim] --> [b, s, head * head_dim]
+             # [b * s, h, head_size] --> [b, s, h * head_size]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

-             # [b, s, head, head_dim] --> [b, s, head, head_dim]
+             # [b, s, h * head_size] --> [b, s, h * head_size]
            output, _ = self.proj(output)
        else:
-             # [b * s, head, head_dim] --> [b, s, head, head_dim]
-             context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-             # [s, b, num_heads * head_size]
+             # [b * s, h, head_size] --> [s, b, h * head_size]
            context_layer = rearrange(
-                 context_layer, "b s h d -> s b (h d)"
+                 output, "(b s) h d -> s b (h d)", b=bsz, s=s
            ).contiguous()

-             # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+             # [s, b, h * head_size] --> [s, b, h * head_size]
            output, _ = self.proj(context_layer)

+         # [s, b, h * head_size] --> [b, s, h * head_size]
        output = output.view(bsz, s, -1)

        return output
+
+
+ class VisionSdpaAttention(nn.Module):
+     r"""
+     Scaled dot-product attention (SDPA) backend for vision attention.
+     """
+
+     # TODO: Should this cache be released once it is no longer needed?
+     _mask_cache = {}
+
+     def __init__(
+         self,
+         head_size: int,
+         dropout: float = 0.0,
+         flatten_batch: bool = False,
+         use_full_precision_softmax: bool = False,
+     ):
+         super().__init__()
+         self.head_size = head_size
+         self.flatten_batch = flatten_batch
+         self.use_full_precision_softmax = use_full_precision_softmax
+         self.dropout = dropout
+
+     def generate_patch_attention_mask(
+         self,
+         s: int,
+         bsz: int,
+         device,
+         cu_seqlens: Optional[torch.Tensor],
+         flatten_batch: bool = False,
+         dtype=torch.bfloat16,
+     ) -> torch.Tensor:
+         r"""
+         Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+         When `flatten_batch` is True:
+             - All sequences in the batch are flattened into a single dimension
+             - `s` is the total number of tokens across all sequences in the batch
+             - Returns a unified mask of shape `(1, 1, s, s)`
+
+         When `flatten_batch` is False:
+             - Each sequence has its own attention mask
+             - `s` is the maximum sequence length in the batch
+             - Returns separate masks of shape `(b, 1, s, s)`
+
+         Args:
+             flatten_batch (bool):
+                 If True, treats all sequences in the batch as a single flattened sequence
+                 If False, generates separate masks for each sequence
+
+         Returns:
+             Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+         """
+         if cu_seqlens is None:
+             raise ValueError("Internal Error: cu_seqlens cannot be None")
+
+         cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
+
+         if cache_key in VisionSdpaAttention._mask_cache:
+             cached_mask = VisionSdpaAttention._mask_cache[cache_key]
+             # print(f"cache hit for key: {cache_key}")
+             return cached_mask.to(device=device, dtype=dtype)
+
+         if flatten_batch:
+             mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+             for i in range(1, len(cu_seqlens)):
+                 start = cu_seqlens[i - 1]
+                 end = cu_seqlens[i]
+                 mask[
+                     ...,
+                     start:end,
+                     start:end,
+                 ] = True
+         else:
+             # [1, 1, 1, s]
+             row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+             # [1, 1, s, 1]
+             col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+             # [b, 1, 1, 1]
+             seq_lens = (
+                 (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
+             )
+
+             mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+
+         # Convert the boolean mask to an additive mask:
+         # allowed positions -> 0, masked positions -> dtype's minimum value
+         mask = (~mask).to(dtype) * torch.finfo(dtype).min
+
+         VisionSdpaAttention._mask_cache[cache_key] = mask
+
+         return mask
+
+     def forward(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         bsz: int,
+         cu_seqlens: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         r"""
+         Args:
+             cu_seqlens: [b + 1]
+         Returns:
+             [b * s, h, head_size]
+         """
+
+         s = q.shape[0] // bsz
+
+         # [b, 1, s, s]
+         if attention_mask is None:
+             attention_mask = self.generate_patch_attention_mask(
+                 s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+             )
+         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
+         if self.use_full_precision_softmax:
+             scale = self.head_size**-0.5
+             k_transposed = rearrange(k, "b h s d -> b h d s")
+             attn_weights = torch.matmul(q, k_transposed) * scale
+             del k, k_transposed
+             attn_weights = attn_weights + attention_mask
+             del attention_mask
+             # full-precision softmax
+             attn_weights = nn.functional.softmax(
+                 attn_weights, dim=-1, dtype=torch.float32
+             ).to(q.dtype)
+             attn_weights = nn.functional.dropout(
+                 attn_weights, p=self.dropout, training=False
+             )
+             output = torch.matmul(attn_weights, v)
+             del attn_weights, v
+         else:
+             # SDPA
+             # [b, h, s, head_size]
+             output = F.scaled_dot_product_attention(
+                 q, k, v, attention_mask, dropout_p=self.dropout
+             )
+
+         # [b, h, s, head_size] --> [b * s, h, head_size]
+         output = rearrange(output, "b h s d -> (b s) h d")
+
+         return output
+
+
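# Minimal usage sketch for the SDPA backend (illustrative; head_size, shapes, and
# cu_seqlens values are assumptions). q, k, v are packed as [b * s, h, head_size]
# and cu_seqlens holds cumulative sequence lengths with a leading zero:
#
#     backend = VisionSdpaAttention(head_size=80, use_full_precision_softmax=True)
#     cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32)
#     out = backend.forward(q, k, v, bsz=2, cu_seqlens=cu_seqlens)
#     # out: [b * s, h, head_size]
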
+ class VisionTritonAttention(nn.Module):
+     """
+     Triton-implemented attention without a causal mask
+     """
+
+     def __init__(
+         self,
+     ):
+         super().__init__()
+
+     def forward(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         _bsz: int,
+         cu_seqlens: Optional[torch.Tensor],
+         **kwargs,
+     ) -> torch.Tensor:
+         r"""
+         Args:
+             cu_seqlens: [b + 1]
+         Returns:
+             [b * s, h, head_size]
+         """
+
+         # [b * s, head, head_size]
+         output = torch.empty_like(q)
+         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+         max_seqlen = seq_lens.max().item()
+         context_attention_fwd(
+             q,
+             k,
+             v,
+             output,
+             cu_seqlens.cuda(),
+             seq_lens.cuda(),
+             max_seqlen,
+             is_causal=False,
+         )
+
+         return output
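
# The Triton backend consumes the same packed q/k/v layout but takes no attention
# mask; variable-length sequences are described only by cu_seqlens, and the kernel
# runs with is_causal=False (illustrative sketch; values are assumptions):
#
#     backend = VisionTritonAttention()
#     cu_seqlens = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda")
#     out = backend.forward(q, k, v, _bsz=2, cu_seqlens=cu_seqlens)
#     # out: [b * s, h, head_size]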