@@ -165,60 +165,21 @@ def attention(self,
             )
         else:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
-            if mode == "prefill":
-                query_states_to_concat = []
-                key_states_to_concat = []
-                value_states_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * groupsize],
-                                                   end=[1, seq_len, (i + 1) * groupsize])
-                    query_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    key_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    value_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                query_states = sum(query_states_to_concat)
-                key_states = sum(key_states_to_concat)
-                value_states = sum(value_states_to_concat)
-            else:
-                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
-                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                  hidden_size, self.n_splits_linear,
-                                                  wt_dtype=self.dtype,
-                                                  scale_factor=(self.group_size == 0))
-                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
+            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
+            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                              hidden_size, self.n_splits_linear,
+                                              wt_dtype=self.dtype,
+                                              scale_factor=(self.group_size == 0),
+                                              is_prefill=(mode == "prefill"))
+            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))

         if q_bias is not None:
             query_states = query_states + q_bias
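Note on this hunk: the removed prefill branch lowered each projection to a slice-project-sum over groups of input channels; the new `is_prefill` keyword pushes that choice down into `dq_split_linear`, which presumably selects the equivalent lowering internally. A minimal NumPy sketch of the slice-project-sum semantics (the name `split_linear` and the shapes are illustrative, not the project's API):

```python
import numpy as np

def split_linear(x, split_weights):
    """Sum of per-split projections over input-channel groups.

    x:             (1, seq_len, hidden_size) activations
    split_weights: list of (hidden_size // n_splits, out_dim) matrices
    """
    n_splits = len(split_weights)
    group_size = x.shape[-1] // n_splits
    out = 0
    for i, w in enumerate(split_weights):
        # Mirrors the removed loop: slice one channel group, project, accumulate.
        out = out + x[..., i * group_size:(i + 1) * group_size] @ w
    return out

# Sanity check: slicing then summing equals one full-width projection.
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 8))
w = rng.standard_normal((8, 6))
assert np.allclose(split_linear(x, np.split(w, 2, axis=0)), x @ w)
```

The partial sums reproduce the single full-width matmul exactly, which is why the two code paths are interchangeable and can be hidden behind one call.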
@@ -296,23 +257,10 @@ def attention(self,
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
         else:
-            if mode == "prefill":
-                attn_output_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_attn_output = self.slice(attn_output,
-                                                 begin=[0, 0, i * groupsize],
-                                                 end=[1, seq_len, (i + 1) * groupsize])
-                    attn_output_to_concat.append(
-                        self.linear(
-                            sub_attn_output, hidden_size, groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                attn_output = sum(attn_output_to_concat)
-            else:
-                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                                   self.n_splits_linear, wt_dtype=self.dtype,
-                                                   scale_factor=(self.group_size == 0))
+            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                               self.n_splits_linear, wt_dtype=self.dtype,
+                                               scale_factor=(self.group_size == 0),
+                                               is_prefill=(mode == "prefill"))

         return attn_output, new_key_states, new_value_states

@@ -488,37 +436,14 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                gate_up_groupsize = self.hidden_size // self.n_splits_linear
-                mm1_to_concat = []
-                mm2_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * gate_up_groupsize],
-                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
-                    mm1_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    mm2_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                mm1 = sum(mm1_to_concat)
-                mm2 = sum(mm2_to_concat)
-            else:
-                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
-                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
+            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
+            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]

         if self.n_splits_down_proj == 1:
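Context for this hunk: `mm1` and `mm2` are the gate and up projections of a SwiGLU-style MLP, combined as `swish(mm1) * mm2` before the down projection handled in the next hunk. A minimal NumPy sketch of that dataflow (function names and shapes are illustrative only):

```python
import numpy as np

def silu(x):
    # swish / SiLU activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_mlp(x, w_gate, w_up, w_down):
    """swish(gate(x)) * up(x), then down-projected, as in the diff:
    mm1 = eltwise_mul(swish(mm1), mm2); hidden_states = down_proj(mm1)."""
    mm1 = x @ w_gate    # gate path (mm1 before activation)
    mm2 = x @ w_up      # up path (mm2)
    return (silu(mm1) * mm2) @ w_down

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 8))        # (batch, seq_len, hidden_size)
w_gate = rng.standard_normal((8, 16))     # hidden_size -> intermediate_size
w_up = rng.standard_normal((8, 16))
w_down = rng.standard_normal((16, 8))     # intermediate_size -> hidden_size
assert swiglu_mlp(x, w_gate, w_up, w_down).shape == (1, 4, 8)
```

Both gate and up projections read the same `hidden_states`, so the diff splits them identically along `hidden_size`; the down projection below instead splits along `intermediate_size`, hence the separate `n_splits_down_proj`.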
@@ -527,23 +452,10 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             )
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                down_groupsize = self.intermediate_size // self.n_splits_down_proj
-                hidden_states_to_concat = []
-                for i in range(self.n_splits_down_proj):
-                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
-                                         end=[1, seq_len, (i + 1) * down_groupsize])
-                    hidden_states_to_concat.append(
-                        self.linear(
-                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                hidden_states = sum(hidden_states_to_concat)
-            else:
-                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                     scale_factor=(self.group_size == 0))
+            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                 scale_factor=(self.group_size == 0),
+                                                 is_prefill=(mode == "prefill"))
         return hidden_states

     def layer_norm(self, hidden_states, layernorm_weight):
@@ -660,9 +572,11 @@ def dq_split_linear(self,
                         n_splits: int,
                         act_dtype: npt.DTypeLike = np.float16,
                         wt_dtype: npt.DTypeLike = np.float16,
-                        scale_factor: bool = False):
+                        scale_factor: bool = False,
+                        is_prefill: bool = False):
         op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
-                                     False, act_dtype, wt_dtype, scale_factor)
+                                     False, act_dtype, wt_dtype, scale_factor,
+                                     is_prefill=is_prefill)
         self.linear_ops.append(op)
         return op

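The wrapper change itself is mechanical: `is_prefill` defaults to `False`, so existing decode-path callers are unaffected, and the flag is forwarded to the superclass by keyword. A standalone sketch of this record-and-forward pattern (the class names and dict-based op stub are illustrative, not ipex-llm's actual base class):

```python
import numpy as np

class BaseGraph:
    # Stand-in for the real op builder; note its positional order
    # (n_splits before the channel counts) differs from the wrapper's.
    def dq_split_linear(self, input_node, n_splits, output_channels,
                        input_channels, bias, act_dtype, wt_dtype,
                        scale_factor, is_prefill=False):
        return {"op": "dq_split_linear", "n_splits": n_splits,
                "out": output_channels, "in": input_channels,
                "is_prefill": is_prefill}

class LLMGraph(BaseGraph):
    def __init__(self):
        self.linear_ops = []

    def dq_split_linear(self, input_node, output_channels, input_channels,
                        n_splits, act_dtype=np.float16, wt_dtype=np.float16,
                        scale_factor=False, is_prefill=False):
        # Reorder arguments for the base class, forward the new flag by
        # keyword, and record the op, mirroring the diff above.
        op = super().dq_split_linear(input_node, n_splits, output_channels,
                                     input_channels, False, act_dtype,
                                     wt_dtype, scale_factor,
                                     is_prefill=is_prefill)
        self.linear_ops.append(op)
        return op

g = LLMGraph()
g.dq_split_linear(None, 64, 128, 2, is_prefill=True)
assert g.linear_ops[0]["is_prefill"] is True
```

Forwarding by keyword rather than position keeps the call robust if further parameters are appended to the base signature later.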