-
Notifications
You must be signed in to change notification settings - Fork 95
Expand file tree
/
Copy pathgqa.py
More file actions
704 lines (632 loc) · 27.5 KB
/
gqa.py
File metadata and controls
704 lines (632 loc) · 27.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/modules/attention/gqa.py
import enum
import logging
import torch
from neuronx_distributed.parallel_layers import parallel_state
from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear
from neuronx_distributed.parallel_layers.pad import get_number_of_extra_heads
from torch import nn
from torch.distributed import ProcessGroup
from torch.nn import functional as F
# Module logger. It is named "Neuron" explicitly rather than via __name__; in this file it
# is used for debug output from the QKV projection's forward path.
logger = logging.getLogger("Neuron")
class GQA(enum.Enum):
    """Strategies for laying out GQA/MQA heads so they can be sharded across TP ranks."""

    # This transforms a GQA attention mechanism into a traditional MHA mechanism
    # by replicating the K/V heads to evenly match the corresponding Q heads.
    # This consumes more memory than would otherwise be used with other sharding
    # mechanisms but works in all cases.
    # Example:
    # tp_degree = 32
    # num_attention_heads: 56 -> 64
    # num_key_value_heads: 8 -> 64
    # Each K/V head is repeated once per Q head of its group (56 // 8 = 7 times), then
    # both Q and K/V are zero-padded at the tail (2 heads per rank shown below):
    # | K1 K1 | K1 K1 | K1 K1 | K1 K2 | ... | K8 K8   | Pad Pad | ... | Pad Pad |
    # | Q1 Q2 | Q3 Q4 | Q5 Q6 | Q7 Q8 | ... | Q55 Q56 | Pad Pad | ... | Pad Pad |
    CONVERT_TO_MHA = "convert-to-mha"
    # This transforms a GQA attention mechanism such that there is exactly
    # one K/V head per TP rank through replication, e.g. 8 K/V heads with
    # tp_degree=32 results in 32 K/V heads. This is more memory efficient but
    # does not work for all configurations (tp_degree must be divisible by the
    # number of K/V heads). Q heads are padded interleaved (zero heads inserted
    # after each Q group) to retain correct alignment between Q and K/V heads.
    # Example:
    # tp_degree = 32
    # num_attention_heads: 56 -> 64
    # num_key_value_heads: 8 -> 32
    # | K1    | K1    | K1    | K1     | K2    | ...
    # | Q1 Q2 | Q3 Q4 | Q5 Q6 | Q7 Pad | Q8 Q9 | ...
    REPLICATE_TO_TP_DEGREE = "replicate-to-tp-degree"
def get_shardable_head_counts(
    tp_degree: int, num_attention_heads: int, num_key_value_heads: int
) -> tuple[GQA, int, int]:
    """Choose a GQA sharding strategy and the head counts to shard over ``tp_degree`` ranks.

    The strategy is ``GQA.REPLICATE_TO_TP_DEGREE`` when ``tp_degree`` is divisible by
    ``num_key_value_heads`` and ``GQA.CONVERT_TO_MHA`` otherwise.

    Query heads grow by ``get_number_of_extra_heads(num_attention_heads, tp_degree)``.
    Presumably this rounds up to a multiple of ``tp_degree``, e.g. 56 -> 64 for
    tp_degree=32 as in the ``GQA`` examples.

    The K/V head count becomes:

    * MHA (equal source head counts): the padded query head count;
    * K/V already shardable (``num_key_value_heads >= tp_degree`` and divisible by it):
      unchanged. This holds even when the strategy is ``CONVERT_TO_MHA``;
    * otherwise: ``tp_degree`` for REPLICATE_TO_TP_DEGREE, or the padded query head
      count for CONVERT_TO_MHA.

    Returns:
        ``(strategy, num_attention_heads, num_key_value_heads)`` after padding/replication.

    Raises:
        ZeroDivisionError: if ``num_key_value_heads`` is 0 (from the divisibility test).
    """
    kv_divides_tp = tp_degree % num_key_value_heads == 0
    strategy = GQA.REPLICATE_TO_TP_DEGREE if kv_divides_tp else GQA.CONVERT_TO_MHA

    padded_q_heads = num_attention_heads + get_number_of_extra_heads(num_attention_heads, tp_degree)

    if num_attention_heads == num_key_value_heads:
        # MHA: K/V simply follow the (possibly padded) query head count.
        return strategy, padded_q_heads, padded_q_heads

    # GQA / MQA
    kv_already_shardable = num_key_value_heads >= tp_degree and num_key_value_heads % tp_degree == 0
    if kv_already_shardable:
        return strategy, padded_q_heads, num_key_value_heads

    if strategy is GQA.REPLICATE_TO_TP_DEGREE:
        assert kv_divides_tp, (
            "GQA.REPLICATE_TO_TP_DEGREE requires tp_degree to be divisible by num_key_value_heads"
        )
        # One K/V head per rank.
        return strategy, padded_q_heads, tp_degree

    # CONVERT_TO_MHA: one K/V head per (padded) query head.
    return strategy, padded_q_heads, padded_q_heads
def maybe_pad_interleaved(tensor, pad_dim: int, source_heads: int, target_heads: int, source_group_size: int):
if tensor is None:
return tensor
# Why we convert FP8 tensor to bfloat16?
# Torch does not support torch.cat, or torch.zeros (for large dimensions) for f8e4m3/f8e5m2
# So we cast it to bfloat16, perform padding, and then recast back to f8e4m3/f8e5m2
recast_dtype = None
if tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
recast_dtype = tensor.dtype
tensor = tensor.to(torch.bfloat16)
shape = (
tensor.shape[:pad_dim] + (source_heads, tensor.shape[pad_dim] // source_heads) + tensor.shape[pad_dim + 1 :]
)
tensor = tensor.view(shape)
splits = torch.split(tensor, source_group_size, dim=pad_dim)
pad_size = list(splits[0].size())
pad_size[pad_dim] = (target_heads - source_heads) // (source_heads // source_group_size)
pads = [torch.zeros(pad_size, dtype=tensor.dtype)] * len(splits)
interleaved = [t for pair in zip(splits, pads) for t in pair]
tensor = torch.cat(interleaved, dim=pad_dim)
shape = tensor.shape[:pad_dim] + (tensor.shape[pad_dim] * tensor.shape[pad_dim + 1],) + tensor.shape[pad_dim + 2 :]
if recast_dtype is not None:
tensor = tensor.to(recast_dtype)
return tensor.view(shape)
def maybe_pad_tail(tensor, source_heads: int, target_heads: int, pad_dim: int):
    """Append zero heads at the end of ``pad_dim`` to grow ``source_heads`` to ``target_heads``.

    Dimension ``pad_dim`` is treated as ``source_heads`` heads of size
    ``tensor.shape[pad_dim] // source_heads``. Enough zeros are appended to hold
    ``target_heads`` such heads.

    Args:
        tensor: tensor to pad, or ``None``.
        source_heads: number of heads currently in ``pad_dim``.
        target_heads: number of heads wanted after padding.
        pad_dim: dimension to pad (non-negative index).

    Returns:
        The padded tensor, or ``None`` if ``tensor`` is ``None``.

    Raises:
        ValueError: if the computed padding is negative. ``F.pad`` would then silently
            truncate the tensor instead of padding it, which means ``source_heads`` does not
            describe the tensor's actual layout.
    """
    if tensor is None:
        return tensor
    head_size = tensor.shape[pad_dim] // source_heads
    size_to_pad = int(head_size * target_heads - tensor.shape[pad_dim])
    if size_to_pad < 0:
        raise ValueError(
            f"maybe_pad_tail would truncate dim {pad_dim} of a tensor with shape {tuple(tensor.shape)} "
            f"(source_heads={source_heads}, target_heads={target_heads}); "
            "source_heads does not match the tensor layout"
        )
    # F.pad takes (left, right) pairs starting from the LAST dimension. To pad only the
    # right side of pad_dim, emit zeros for every later dimension plus pad_dim's left side.
    num_dims_to_describe = tensor.dim() - pad_dim
    pad = (0,) * (2 * num_dims_to_describe - 1) + (size_to_pad,)
    return F.pad(tensor, pad)
def replicate_kv(tensor, source_heads: int, repeats: int, head_dim=0):
    """Repeat every head of ``tensor`` ``repeats`` times, keeping copies adjacent.

    Despite its name, ``head_dim`` is the index of the dimension holding the flattened
    ``(source_heads, head_size)`` data, not the size of a head. With heads ``[H1, H2]`` and
    ``repeats=2`` the result is ``[H1, H1, H2, H2]`` along that dimension.

    Args:
        tensor: tensor to replicate, or ``None``.
        source_heads: number of heads in dimension ``head_dim``.
        repeats: number of copies of each head.
        head_dim: dimension index holding the heads.

    Returns:
        The replicated tensor, or ``None`` if ``tensor`` is ``None``.
    """
    if tensor is None:
        return tensor

    leading = tensor.shape[:head_dim]
    trailing = tensor.shape[head_dim + 1 :]
    per_head = tensor.shape[head_dim] // source_heads

    # (..., heads * head_size, ...) -> (..., heads, head_size, ...)
    unflattened = tensor.view(leading + (source_heads, per_head) + trailing)
    repeated = torch.repeat_interleave(unflattened, repeats=repeats, dim=head_dim)

    # (..., heads * repeats, head_size, ...) -> (..., heads * repeats * head_size, ...)
    flat_size = repeated.shape[head_dim] * repeated.shape[head_dim + 1]
    return repeated.view(repeated.shape[:head_dim] + (flat_size,) + repeated.shape[head_dim + 2 :])
class BaseGroupQueryAttention(nn.Module):
    """Shared head-count bookkeeping for the GQA projection modules.

    Resolves the tensor-parallel group and degree. It records the source head counts in
    ``_src_num_attention_heads`` / ``_src_num_key_value_heads``, and the shardable counts
    plus chosen :class:`GQA` strategy returned by :func:`get_shardable_head_counts` in
    ``num_attention_heads`` / ``num_key_value_heads`` / ``sharding_strategy``.
    Subclasses must implement :meth:`preshard_hook`.
    """

    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        tp_degree: int = 1,
        dtype: torch.dtype = torch.float32,
        bias: bool = False,
        tensor_model_parallel_group: ProcessGroup | None = None,
    ):
        """Resolve TP settings and compute the shardable head counts.

        If ``tensor_model_parallel_group`` is None, the group comes from
        ``parallel_state.get_tensor_model_parallel_group()``. Only an explicitly passed
        group is used to adjust ``tp_degree``:

        * if ``tp_degree`` is left at 1, it becomes the group's size;
        * otherwise it must equal the group's size (checked with ``assert``).

        NOTE(review): a group obtained from ``parallel_state`` is not used to infer or
        check ``tp_degree``. Confirm that callers always pass ``tp_degree`` in that case.
        """
        super().__init__()
        self.tensor_model_parallel_group = (
            tensor_model_parallel_group
            if tensor_model_parallel_group is not None
            else parallel_state.get_tensor_model_parallel_group()
        )

        if tensor_model_parallel_group:
            group_size = self.tensor_model_parallel_group.size()
            if tp_degree == 1:
                # update default value
                tp_degree = group_size
            else:
                assert tp_degree == group_size, (
                    f"TP Degree {tp_degree} and tensor model parallel group size {group_size} does not match"
                )

        self.hidden_size = hidden_size
        self.tp_degree = tp_degree
        self.head_dim = head_dim
        self.dtype = dtype
        self.bias = bias

        # Keep the checkpoint's head counts; the public counts below are the padded /
        # replicated ones that the parallel layers are built with.
        self._src_num_attention_heads = num_attention_heads
        self._src_num_key_value_heads = num_key_value_heads
        strategy, sharded_q_heads, sharded_kv_heads = get_shardable_head_counts(
            tp_degree,
            self._src_num_attention_heads,
            self._src_num_key_value_heads,
        )
        self.sharding_strategy = strategy
        self.num_attention_heads = sharded_q_heads
        self.num_key_value_heads = sharded_kv_heads

    def get_sharding_strategy(self) -> GQA:
        """Return the :class:`GQA` strategy chosen at construction."""
        return self.sharding_strategy

    def get_num_attention_heads(self) -> int:
        """Return the (possibly padded) number of query heads."""
        return self.num_attention_heads

    def get_num_key_value_heads(self) -> int:
        """Return the (possibly replicated / padded) number of key/value heads."""
        return self.num_key_value_heads

    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        """Rewrite checkpoint entries for this module; implemented by subclasses."""
        raise NotImplementedError

    def replace_prefixes(self, old_prefix, new_prefix, model_state_dict):
        """Rename, in place, every key of ``model_state_dict`` that contains ``old_prefix``.

        Matching is by substring (not only at the start of the key), and every occurrence
        of ``old_prefix`` in a matching key is replaced by ``new_prefix``. Renames are
        collected first so the dict is not mutated while it is being iterated.
        """
        renames = [(key, key.replace(old_prefix, new_prefix)) for key in model_state_dict.keys() if old_prefix in key]
        for old_key, new_key in renames:
            model_state_dict[new_key] = model_state_dict.pop(old_key)
class GroupQueryAttention_QKV(BaseGroupQueryAttention):
    """Q/K/V input projections for (grouped-query) attention.

    Builds either one fused ``Wqkv`` projection (``fused_qkv=True``) or separate
    ``q_proj`` / ``k_proj`` / ``v_proj`` projections. They are ``ColumnParallelLinear``
    layers when a tensor-parallel group is available, and plain ``nn.Linear`` otherwise.
    Output sizes use the sharding-adjusted head counts from the base class, and
    :meth:`preshard_hook` converts source-checkpoint weights to that layout.
    """

    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        tp_degree: int = 1,
        dtype: torch.dtype = torch.float32,
        bias: bool = False,
        gather_output: bool = True,
        fused_qkv: bool = False,
        clip_qkv: float | None = None,
        tensor_model_parallel_group: ProcessGroup | None = None,
        rms_norm_eps: float | None = None,
        logical_nc_config: int = 1,
    ):
        """Create the projection layers.

        Args:
            hidden_size: input feature size.
            head_dim: size of one attention head.
            num_attention_heads: number of query heads in the source model.
            num_key_value_heads: number of key/value heads in the source model.
            tp_degree: tensor-parallel degree (see ``BaseGroupQueryAttention``).
            dtype: parameter dtype, passed to the parallel layers only.
            bias: whether the projections have a bias.
            gather_output: forwarded to ``ColumnParallelLinear``.
            fused_qkv: use one fused ``Wqkv`` projection instead of three.
            clip_qkv: if set, Q/K/V outputs are clamped to ``[-clip_qkv, clip_qkv]``.
            tensor_model_parallel_group: explicit TP process group.
            rms_norm_eps: stored on the module; not used by this class.
            logical_nc_config: stored on the module; not used by this class.

        Raises:
            ValueError: if both ``fused_qkv`` and ``gather_output`` are true, since the
                fused weight uses a different sharding scheme.
        """
        super().__init__(
            hidden_size=hidden_size,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            tp_degree=tp_degree,
            dtype=dtype,
            bias=bias,
            tensor_model_parallel_group=tensor_model_parallel_group,
        )
        if fused_qkv and gather_output:
            raise ValueError(
                "Gathering states followed by fused qkv is not allowed as it has a different weight sharding scheme."
            )
        self.gather_output = gather_output
        self.fused_qkv = fused_qkv
        self.clip_qkv = clip_qkv
        self.rms_norm_eps = rms_norm_eps
        self.logical_nc_config = logical_nc_config

        q_size = self.num_attention_heads * self.head_dim
        kv_size = self.num_key_value_heads * self.head_dim
        if self.tensor_model_parallel_group is not None:
            if self.fused_qkv:
                self.Wqkv = ColumnParallelLinear(
                    self.hidden_size,
                    q_size + 2 * kv_size,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
                # Set heads info as weight parameter attributes to be used in weights sharding
                setattr(self.Wqkv.weight, "fused_qkv", True)
                setattr(self.Wqkv.weight, "num_attention_heads", self.num_attention_heads)
                setattr(self.Wqkv.weight, "num_key_value_heads", self.num_key_value_heads)
                setattr(self.Wqkv.weight, "head_dim", self.head_dim)
            else:
                self.q_proj = ColumnParallelLinear(
                    self.hidden_size,
                    q_size,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
                self.k_proj = ColumnParallelLinear(
                    self.hidden_size,
                    kv_size,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
                self.v_proj = ColumnParallelLinear(
                    self.hidden_size,
                    kv_size,
                    bias=self.bias,
                    gather_output=self.gather_output,
                    dtype=dtype,
                    tensor_model_parallel_group=self.tensor_model_parallel_group,
                )
        else:
            # NOTE(review): dtype is not forwarded to nn.Linear here (default dtype is used).
            if self.fused_qkv:
                self.Wqkv = nn.Linear(self.hidden_size, q_size + 2 * kv_size, bias=self.bias)
            else:
                self.q_proj = nn.Linear(self.hidden_size, q_size, bias=self.bias)
                self.k_proj = nn.Linear(self.hidden_size, kv_size, bias=self.bias)
                self.v_proj = nn.Linear(self.hidden_size, kv_size, bias=self.bias)

    def forward(self, hidden_states: torch.Tensor):
        """Project ``hidden_states`` to ``(Q, K, V)``, clamping to ``clip_qkv`` when set."""
        if self.fused_qkv:
            logger.debug("QKV: native compiler")
            QKV = self.Wqkv(hidden_states)
            return self._split_fused_qkv(QKV)
        Q = self.q_proj(hidden_states)
        K = self.k_proj(hidden_states)
        V = self.v_proj(hidden_states)
        if self.clip_qkv is not None:
            Q = Q.clamp(min=-self.clip_qkv, max=self.clip_qkv)
            K = K.clamp(min=-self.clip_qkv, max=self.clip_qkv)
            V = V.clamp(min=-self.clip_qkv, max=self.clip_qkv)
        return Q, K, V

    def _split_fused_qkv(self, QKV):
        """Split this rank's fused QKV activation along dim 2 into ``(Q, K, V)``.

        Q and K take ``heads * head_dim // tp_degree`` columns each and V takes the rest.
        This relies on the output not being gathered, which ``__init__`` enforces for
        ``fused_qkv``.
        """
        logger.debug("Fused QKV tensor has shape %s", QKV.shape)
        if self.clip_qkv is not None:
            QKV = QKV.clamp(min=-self.clip_qkv, max=self.clip_qkv)
        # shape of QKV is [batch, seqlen, fused_qkv_size]
        # we split the fused QKV (dim=2) into Q, K, V
        # for example:
        # for 405B, TP=128, num_att_heads=128
        # LNC=2/TP=64 will split QKV from [batch, seqlen, 512] into:
        # Q [batch, seqlen, 256]
        # K [batch, seqlen, 128]
        # V [batch, seqlen, 128]
        # torch.split has accuracy issue and leads to more reshapes in hlo.
        # Using torch.tensor_split here. NAPP-3145
        q_end_index = self.num_attention_heads * self.head_dim // self.tp_degree
        k_end_index = q_end_index + self.num_key_value_heads * self.head_dim // self.tp_degree
        logger.debug("QKV shape before tensor_split: %s", QKV.shape)
        # The rest of the QKV after k_end_index goes to the V output.
        Q, K, V = torch.tensor_split(QKV, (q_end_index, k_end_index), dim=2)
        logger.debug("Q shape after tensor_split: %s", Q.shape)
        logger.debug("K shape after tensor_split: %s", K.shape)
        logger.debug("V shape after tensor_split: %s", V.shape)
        return Q, K, V

    def get_weight(self, prefix: str, layer: torch.nn.Module, layer_name, model_state_dict: dict) -> torch.Tensor:
        """Read ``<prefix>.<layer_name>.weight``, via the layer's own loader if it has one."""
        if hasattr(layer, "get_weight_from_state_dict"):
            return layer.get_weight_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict)
        return model_state_dict[f"{prefix}.{layer_name}.weight"]

    def get_bias(
        self, prefix: str, layer: torch.nn.Module, layer_name: str, model_state_dict: dict
    ) -> torch.Tensor | None:
        """Read ``<prefix>.<layer_name>.bias`` (``None`` if absent), via the layer's loader if any."""
        if hasattr(layer, "get_bias_from_state_dict"):
            return layer.get_bias_from_state_dict(prefix=f"{prefix}.{layer_name}.", state_dict=model_state_dict)
        return model_state_dict.get(f"{prefix}.{layer_name}.bias")

    def set_weight(
        self,
        tensor: torch.Tensor,
        prefix: str,
        layer: torch.nn.Module,
        layer_name,
        model_state_dict: dict,
    ) -> None:
        """Store ``tensor`` as ``<prefix>.<layer_name>.weight`` in ``model_state_dict``."""
        # TODO: set weight to state dict support is pending.
        model_state_dict[f"{prefix}.{layer_name}.weight"] = tensor

    def set_bias(
        self,
        tensor: torch.Tensor,
        prefix: str,
        layer: torch.nn.Module,
        layer_name: str,
        model_state_dict: dict,
    ) -> None:
        """Store ``tensor`` as the layer's bias, via the layer's own setter if it has one."""
        if hasattr(layer, "set_bias_to_state_dict"):
            layer.set_bias_to_state_dict(prefix=f"{prefix}.{layer_name}.", tensor=tensor, state_dict=model_state_dict)
        else:
            model_state_dict[f"{prefix}.{layer_name}.bias"] = tensor

    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        """Convert this module's Q/K/V checkpoint entries in place to the sharded head layout.

        The module key prefix is ``prefix`` without its last "."-separated component, and
        the source (HF) prefix is ``prefix`` without its last two components. Steps:

        1. Rename keys under ``<hf_prefix>.{Wqkv|q_proj|k_proj|v_proj}`` to the module prefix.
        2. Read the weights/biases. A fused ``Wqkv`` is split by the source head counts.
        3. If the K/V head count changes, repeat each K/V head. The repeat count is
           ``tp_degree // kv`` for REPLICATE_TO_TP_DEGREE and ``q // kv`` for CONVERT_TO_MHA.
        4. For REPLICATE_TO_TP_DEGREE, pad Q with zero heads interleaved after each group.
           For CONVERT_TO_MHA, zero-pad Q and K/V at the tail up to their target counts.
        5. Write the results back, re-fusing into ``Wqkv`` when fused. Biases are written
           only if ``self.bias``.

        Returns:
            True.
        """
        prefix_parts = prefix.split(".")
        prefix = ".".join(prefix_parts[:-1])
        hf_prefix = ".".join(prefix_parts[:-2])

        src_q_size = self._src_num_attention_heads * self.head_dim
        src_kv_size = self._src_num_key_value_heads * self.head_dim

        if self.fused_qkv:
            self.replace_prefixes(
                old_prefix=f"{hf_prefix}.Wqkv",
                new_prefix=f"{prefix}.Wqkv",
                model_state_dict=model_state_dict,
            )
            qkv_weight = self.get_weight(
                prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
            )
            q_proj_weight, k_proj_weight, v_proj_weight = qkv_weight.split(
                [src_q_size, src_kv_size, src_kv_size], dim=0
            )
            qkv_bias = self.get_bias(
                prefix=prefix, layer=self.Wqkv, layer_name="Wqkv", model_state_dict=model_state_dict
            )
            if qkv_bias is not None:
                q_proj_bias, k_proj_bias, v_proj_bias = qkv_bias.split(
                    [src_q_size, src_kv_size, src_kv_size], dim=0
                )
            else:
                q_proj_bias, k_proj_bias, v_proj_bias = None, None, None
        else:
            layers = {"q_proj": self.q_proj, "k_proj": self.k_proj, "v_proj": self.v_proj}
            for name in layers:
                self.replace_prefixes(
                    old_prefix=f"{hf_prefix}.{name}",
                    new_prefix=f"{prefix}.{name}",
                    model_state_dict=model_state_dict,
                )
            q_proj_weight, k_proj_weight, v_proj_weight = [
                self.get_weight(prefix=prefix, layer=layer, layer_name=name, model_state_dict=model_state_dict)
                for name, layer in layers.items()
            ]
            q_proj_bias, k_proj_bias, v_proj_bias = [
                self.get_bias(prefix=prefix, layer=layer, layer_name=name, model_state_dict=model_state_dict)
                for name, layer in layers.items()
            ]

        # Number of heads K/V currently hold. It only changes if they are replicated below.
        # Under CONVERT_TO_MHA the K/V count can legitimately stay unchanged (kv >= tp and
        # kv % tp == 0), so padding must start from this count, not from the Q head count.
        kv_heads = self._src_num_key_value_heads
        if self.num_key_value_heads != self._src_num_key_value_heads:
            if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
                repeats = self.tp_degree // self._src_num_key_value_heads
            elif self.sharding_strategy == GQA.CONVERT_TO_MHA:
                repeats = self._src_num_attention_heads // self._src_num_key_value_heads
            k_proj_weight = replicate_kv(
                k_proj_weight, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
            )
            k_proj_bias = replicate_kv(
                k_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
            )
            v_proj_weight = replicate_kv(
                v_proj_weight, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
            )
            v_proj_bias = replicate_kv(
                v_proj_bias, source_heads=self._src_num_key_value_heads, repeats=repeats, head_dim=0
            )
            kv_heads = self._src_num_key_value_heads * repeats

        if self.sharding_strategy == GQA.REPLICATE_TO_TP_DEGREE:
            group_size = self._src_num_attention_heads // self._src_num_key_value_heads
            q_proj_weight = maybe_pad_interleaved(
                q_proj_weight,
                pad_dim=0,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                source_group_size=group_size,
            )
            q_proj_bias = maybe_pad_interleaved(
                q_proj_bias,
                pad_dim=0,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                source_group_size=group_size,
            )

        if self.sharding_strategy == GQA.CONVERT_TO_MHA:
            q_proj_weight = maybe_pad_tail(
                q_proj_weight,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                pad_dim=0,
            )
            q_proj_bias = maybe_pad_tail(
                q_proj_bias,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                pad_dim=0,
            )
            k_proj_weight = maybe_pad_tail(
                k_proj_weight, source_heads=kv_heads, target_heads=self.num_key_value_heads, pad_dim=0
            )
            k_proj_bias = maybe_pad_tail(
                k_proj_bias, source_heads=kv_heads, target_heads=self.num_key_value_heads, pad_dim=0
            )
            v_proj_weight = maybe_pad_tail(
                v_proj_weight, source_heads=kv_heads, target_heads=self.num_key_value_heads, pad_dim=0
            )
            v_proj_bias = maybe_pad_tail(
                v_proj_bias, source_heads=kv_heads, target_heads=self.num_key_value_heads, pad_dim=0
            )

        if self.fused_qkv:
            qkv_weight = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=0)
            self.set_weight(
                tensor=qkv_weight,
                prefix=prefix,
                layer=self.Wqkv,
                layer_name="Wqkv",
                model_state_dict=model_state_dict,
            )
            if self.bias:
                qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=0)
                self.set_bias(
                    tensor=qkv_bias,
                    prefix=prefix,
                    layer=self.Wqkv,
                    layer_name="Wqkv",
                    model_state_dict=model_state_dict,
                )
        else:
            projections = (
                ("q_proj", self.q_proj, q_proj_weight, q_proj_bias),
                ("k_proj", self.k_proj, k_proj_weight, k_proj_bias),
                ("v_proj", self.v_proj, v_proj_weight, v_proj_bias),
            )
            for name, layer, weight, _ in projections:
                self.set_weight(
                    tensor=weight, prefix=prefix, layer=layer, layer_name=name, model_state_dict=model_state_dict
                )
            if self.bias:
                for name, layer, _, bias_tensor in projections:
                    self.set_bias(
                        tensor=bias_tensor,
                        prefix=prefix,
                        layer=layer,
                        layer_name=name,
                        model_state_dict=model_state_dict,
                    )
        return True
class GroupQueryAttention_O(BaseGroupQueryAttention):
    """Attention output projection whose input dimension follows the sharded head layout.

    Uses ``RowParallelLinear`` when a tensor-parallel group is available and ``nn.Linear``
    otherwise. The input size is ``num_attention_heads * head_dim`` with the padded head
    count, and :meth:`preshard_hook` pads the checkpoint weight's input columns to match.
    """

    def __init__(
        self,
        hidden_size: int,
        head_dim: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        tp_degree: int = 1,
        dtype: torch.dtype = torch.float32,
        bias: bool = False,
        input_is_parallel: bool = False,
        layer_name: str = "o_proj",
        tensor_model_parallel_group: ProcessGroup | None = None,
        rpl_reduce_dtype: torch.dtype = None,
    ):
        """Create the output projection.

        ``layer_name`` is the projection's name in the source checkpoint (for example
        "out_proj" in a CLIP vision model). ``rpl_reduce_dtype`` is forwarded to
        ``RowParallelLinear`` as ``reduce_dtype``.
        """
        super().__init__(
            hidden_size=hidden_size,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            tp_degree=tp_degree,
            dtype=dtype,
            bias=bias,
            tensor_model_parallel_group=tensor_model_parallel_group,
        )
        self.input_is_parallel = input_is_parallel

        in_features = self.num_attention_heads * self.head_dim
        if self.tensor_model_parallel_group is None:
            self.o_proj = nn.Linear(in_features, self.hidden_size, bias=self.bias)
        else:
            self.o_proj = RowParallelLinear(
                in_features,
                self.hidden_size,
                bias=self.bias,
                input_is_parallel=self.input_is_parallel,
                dtype=self.dtype,
                tensor_model_parallel_group=self.tensor_model_parallel_group,
                reduce_dtype=rpl_reduce_dtype,
            )

        # Name to map from in the source state dict; it is renamed to "o_proj" in preshard_hook.
        self.layer_name = layer_name

    def forward(self, attention_output: torch.Tensor):
        """Apply the output projection."""
        return self.o_proj(attention_output)

    def preshard_hook(self, model_state_dict: dict, prefix: str) -> bool:
        """Rename and pad this module's checkpoint weight in place.

        Keys under ``<hf_prefix>.<layer_name>`` are renamed to ``<module_prefix>.o_proj``.
        ``module_prefix`` is ``prefix`` without its last "."-separated component and
        ``hf_prefix`` drops the last two. The weight's input columns (dim 1) are then
        zero-padded to the sharded head count: interleaved per Q group for
        REPLICATE_TO_TP_DEGREE, or at the tail for CONVERT_TO_MHA. Returns True.
        """
        parts = prefix.split(".")
        module_prefix = ".".join(parts[:-1])
        source_prefix = ".".join(parts[:-2])

        self.replace_prefixes(
            old_prefix=f"{source_prefix}.{self.layer_name}",
            new_prefix=f"{module_prefix}.o_proj",
            model_state_dict=model_state_dict,
        )

        weight_key = f"{module_prefix}.o_proj.weight"
        weight = model_state_dict[weight_key]
        strategy = self.sharding_strategy
        if strategy is GQA.REPLICATE_TO_TP_DEGREE:
            weight = maybe_pad_interleaved(
                weight,
                pad_dim=1,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                source_group_size=self._src_num_attention_heads // self._src_num_key_value_heads,
            )
        elif strategy is GQA.CONVERT_TO_MHA:
            weight = maybe_pad_tail(
                weight,
                source_heads=self._src_num_attention_heads,
                target_heads=self.num_attention_heads,
                pad_dim=1,
            )
        model_state_dict[weight_key] = weight
        return True