Commit 70037ad

Groupwise prefill optimization (#12291)
* except lm_head
* remove
* support gw lm_head
* update
* fix
* remove run.bat
* fix style
* support llama3
* slice -> split
* remove debug
* fix style
* add dpu
1 parent 540eaeb · commit 70037ad

3 files changed: +106 -148 lines changed

python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py

Lines changed: 38 additions & 14 deletions
@@ -188,7 +188,10 @@ def __init__(
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])
 
         print("start compiling")
-        self.compile()
+        if mode == "prefill":
+            self.compile(npu_dpu_groups=6)
+        else:
+            self.compile()
 
     def build_decoder(
         self,
@@ -753,19 +756,40 @@ def run_prefill(
 
         weights = []
 
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
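
The split-aware weight packing above reduces to a small pattern: when a projection is split into several dequantized linears, the per-split weights and scales are stacked along a new leading axis, so the prefill graph receives one (n_splits, ...) weight tensor and one matching scale tensor per projection instead of n_splits separate pairs. A minimal standalone sketch of that idea (the helper name pack_dq_splits and the bare .weight/.scale attribute access are illustrative assumptions, not code from this commit):

import torch

def pack_dq_splits(splits):
    # `splits` stands in for one of the *_proj_dq_list lists; each element is
    # assumed to expose `weight` and `scale` tensors of identical shape.
    if len(splits) == 1:
        # Single split: keep the original (weight, scale) tuple layout.
        return splits[0].weight, splits[0].scale
    # Multiple splits: stack into (n_splits, ...) tensors for the prefill graph.
    return (torch.stack([s.weight for s in splits], dim=0),
            torch.stack([s.scale for s in splits], dim=0))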

python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py

Lines changed: 35 additions & 121 deletions
@@ -165,60 +165,21 @@ def attention(self,
             )
         else:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
-            if mode == "prefill":
-                query_states_to_concat = []
-                key_states_to_concat = []
-                value_states_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * groupsize],
-                                                   end=[1, seq_len, (i + 1) * groupsize])
-                    query_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    key_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    value_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                query_states = sum(query_states_to_concat)
-                key_states = sum(key_states_to_concat)
-                value_states = sum(value_states_to_concat)
-            else:
-                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
-                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                  hidden_size, self.n_splits_linear,
-                                                  wt_dtype=self.dtype,
-                                                  scale_factor=(self.group_size == 0))
-                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
+            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
+            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                              hidden_size, self.n_splits_linear,
+                                              wt_dtype=self.dtype,
+                                              scale_factor=(self.group_size == 0),
+                                              is_prefill=(mode == "prefill"))
+            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
 
         if q_bias is not None:
             query_states = query_states + q_bias
@@ -296,23 +257,10 @@ def attention(self,
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
         else:
-            if mode == "prefill":
-                attn_output_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_attn_output = self.slice(attn_output,
-                                                 begin=[0, 0, i * groupsize],
-                                                 end=[1, seq_len, (i + 1) * groupsize])
-                    attn_output_to_concat.append(
-                        self.linear(
-                            sub_attn_output, hidden_size, groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                attn_output = sum(attn_output_to_concat)
-            else:
-                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                                   self.n_splits_linear, wt_dtype=self.dtype,
-                                                   scale_factor=(self.group_size == 0))
+            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                               self.n_splits_linear, wt_dtype=self.dtype,
+                                               scale_factor=(self.group_size == 0),
+                                               is_prefill=(mode == "prefill"))
 
         return attn_output, new_key_states, new_value_states
 
@@ -488,37 +436,14 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                gate_up_groupsize = self.hidden_size // self.n_splits_linear
-                mm1_to_concat = []
-                mm2_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * gate_up_groupsize],
-                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
-                    mm1_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    mm2_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                mm1 = sum(mm1_to_concat)
-                mm2 = sum(mm2_to_concat)
-            else:
-                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
-                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
+            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
+            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
 
         if self.n_splits_down_proj == 1:
@@ -527,23 +452,10 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
             )
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                down_groupsize = self.intermediate_size // self.n_splits_down_proj
-                hidden_states_to_concat = []
-                for i in range(self.n_splits_down_proj):
-                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
-                                         end=[1, seq_len, (i + 1) * down_groupsize])
-                    hidden_states_to_concat.append(
-                        self.linear(
-                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                hidden_states = sum(hidden_states_to_concat)
-            else:
-                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                     scale_factor=(self.group_size == 0))
+            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                 scale_factor=(self.group_size == 0),
+                                                 is_prefill=(mode == "prefill"))
         return hidden_states
 
     def layer_norm(self, hidden_states, layernorm_weight):
@@ -660,9 +572,11 @@ def dq_split_linear(self,
                         n_splits: int,
                         act_dtype: npt.DTypeLike = np.float16,
                         wt_dtype: npt.DTypeLike = np.float16,
-                        scale_factor: bool = False):
+                        scale_factor: bool = False,
+                        is_prefill: bool = False):
         op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
                                      False, act_dtype, wt_dtype, scale_factor,
-                                     False, act_dtype, wt_dtype, scale_factor)
+                                     is_prefill=is_prefill)
        self.linear_ops.append(op)
        return op
 
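
The slice-and-sum prefill branches removed above relied on the fact that a linear projection decomposes over groups of input channels; with this commit, how the grouped projection is lowered during prefill is delegated to the backend through the new is_prefill flag instead of being built by hand in the graph. As a quick sanity check of the underlying identity (shapes and variable names below are illustrative only, not taken from the codebase):

import torch

# Summing partial matmuls over input-channel groups equals one full matmul.
hidden_size, out_features, seq_len, n_splits = 64, 32, 8, 4
groupsize = hidden_size // n_splits

x = torch.randn(1, seq_len, hidden_size)
w = torch.randn(out_features, hidden_size)

full = x @ w.T
grouped = sum(
    x[:, :, i * groupsize:(i + 1) * groupsize] @ w[:, i * groupsize:(i + 1) * groupsize].T
    for i in range(n_splits)
)
assert torch.allclose(full, grouped, atol=1e-4)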

python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py

Lines changed: 33 additions & 13 deletions
@@ -827,20 +827,40 @@ def run_prefill(
         mlp_layer = curr_layer.mlp
 
         weights = []
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
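
The weight-gathering change here mirrors the one in llama_mp.py above. If the duplication grows, one possible follow-up (purely hypothetical, not part of this commit) would be to hoist the block into a shared helper; the sketch assumes each *_dq_list entry exposes .weight and .scale, and that every list holds a single entry when its split count is 1:

import torch

def gather_dq_weights(attn_layer, mlp_layer, n_splits_linear, n_splits_down_proj):
    # Hypothetical helper factoring out the block duplicated between
    # llama_mp.py and qwen2_mp.py run_prefill (illustration only).
    def pack(splits, n_splits):
        if n_splits == 1:
            return [(l.weight, l.scale) for l in splits]
        return [(torch.stack([l.weight for l in splits], dim=0),
                 torch.stack([l.scale for l in splits], dim=0))]

    weights = []
    for splits in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                   attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
                   mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
        weights.extend(pack(splits, n_splits_linear))
    weights.extend(pack(mlp_layer.down_proj_dq_list, n_splits_down_proj))
    return weights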
