
Commit a0c73c2

clean NPU code (#12060)
* clean code
* remove time.perf_counter()
1 parent c75f3dd · commit a0c73c2
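In short: this commit deletes dead code from the four NPU pipeline-parallel model files (the unused `key_value_states` accumulation in `post_forward` and the unused, misspelled `deocderlayers` list in `run_decode`), strips leftover `time.perf_counter()` instrumentation and the commented-out timing print from the fused forward paths, and has rank 0 reuse a preallocated `self.forward_signal` tensor, broadcast with `async_op=True`, instead of allocating a fresh control tensor on every decode step.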

4 files changed: +6 -72 lines

python/llm/src/ipex_llm/transformers/npu_models/baichuan_mp.py

+3 -20
```diff
@@ -415,11 +415,6 @@ def forward(
         return outputs

     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -556,7 +551,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -610,13 +604,12 @@ def run_decode(
     with torch.inference_mode():
         while True:

-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_key_values_length = past_key_values.get_seq_length()
                 seq_length_with_past = 1 + past_key_values_length
                 position_ids = torch.arange(
@@ -636,7 +629,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -645,11 +637,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -674,6 +663,7 @@ def __init__(self, model, max_seq_len, intra_pp=2, inter_pp=2, transpose_value_c
         self.input_queues = []
         self.output_queues = []
         self.decoder_processes = []
+        self.forward_signal = torch.tensor(0, dtype=torch.int)

         for rank in range(1, world_size):
             input_q = mp.Queue()
@@ -721,21 +711,17 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)

-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
         return hidden_states, past_key_value

     def shutdown(self):
@@ -918,7 +904,6 @@ def baichuan_fused_model_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    t0 = time.perf_counter()
     output_attentions = (
         output_attentions if output_attentions is not None else self.config.output_attentions
     )
@@ -1026,8 +1011,6 @@ def baichuan_fused_model_forward(
             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
             if v is not None
         )
-    t1 = time.perf_counter()
-    # print("fused model forward time: ", t1 - t0)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
         past_key_values=next_cache,
```
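The control values in the `run_decode` loop above form a small driver/worker protocol: rank 0 broadcasts -1 before pushing a new KV cache through the input queues, -2 to shut workers down, and the zero-valued `forward_signal` for an ordinary decode step, after which hidden states travel rank to rank via `dist.send`/`dist.recv`. Below is a minimal, self-contained sketch of that pattern; the two-rank setup, `gloo` backend, and the fake decode step are assumptions for illustration, not the actual ipex-llm NPU decoders (which also use float16 tensors rather than the float32 used here).

```python
# Minimal, self-contained sketch of the rank-0 / worker control protocol.
# Assumptions: 2 ranks, gloo backend, float32 tensors, and a fake "decode"
# (multiply by 2) standing in for the real NPU multi_decoder call.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2

def init(rank):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=rank, world_size=WORLD_SIZE)

def worker(rank):
    init(rank)
    control = torch.tensor(0, dtype=torch.int)
    while True:
        dist.broadcast(control, src=0)          # block until rank 0 signals
        if control.item() == -2:                # shutdown
            break
        elif control.item() == -1:              # new KV cache arrives via queue (elided)
            pass
        else:                                   # ordinary decode step
            hidden = torch.empty(4)
            dist.recv(hidden, src=rank - 1)     # hidden states from previous stage
            dist.send(hidden * 2, dst=(rank + 1) % WORLD_SIZE)  # fake decode, pass on
    dist.destroy_process_group()

def driver():
    init(0)
    forward_signal = torch.tensor(0, dtype=torch.int)  # preallocated, as in this commit
    hidden = torch.ones(4)
    for _ in range(3):                                 # three decode steps
        dist.broadcast(forward_signal, src=0, async_op=True)
        dist.send(hidden, dst=1)
        dist.recv(hidden, src=WORLD_SIZE - 1)
    dist.broadcast(torch.tensor(-2, dtype=torch.int), src=0)  # tell workers to exit
    dist.destroy_process_group()
    print(hidden)                                      # tensor([8., 8., 8., 8.])

if __name__ == "__main__":
    p = mp.Process(target=worker, args=(1,))
    p.start()
    driver()
    p.join()
```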

python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py

+0 -17
```diff
@@ -330,11 +330,6 @@ def forward(
         return outputs

     def post_forward(self, past_key_value, new_keys, new_values, cache_position):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "cache_position": cache_position,
             "max_seq_len": self.max_seq_len,
@@ -474,7 +469,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -536,7 +530,6 @@ def run_decode(
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 cache_position = torch.arange(
@@ -555,7 +548,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -565,11 +557,8 @@ def run_decode(
                     use_cache=True,
                     cache_position=cache_position,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -651,14 +640,11 @@ def forward(
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
-        t0 = time.perf_counter()
         dist.broadcast(self.forward_signal, src=0, async_op=True)
-        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t2 = time.perf_counter()
         return hidden_states, past_key_value

     def shutdown(self):
@@ -847,7 +833,6 @@ def llama_fused_model_forward(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    t0 = time.perf_counter()
     output_attentions = (
         output_attentions if output_attentions is not None else self.config.output_attentions
     )
@@ -938,8 +923,6 @@ def llama_fused_model_forward(
             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
             if v is not None
         )
-    t1 = time.perf_counter()
-    # print("fused model forward time: ", t1 - t0)
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
         past_key_values=next_cache,
```
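A note on the `dist.broadcast(self.forward_signal, src=0, async_op=True)` call that the `forward` methods in all four files now share: with `async_op=True`, `torch.distributed.broadcast` returns a `Work` handle instead of blocking, so rank 0 can enqueue the signal and immediately proceed to `dist.send(hidden_states, dst=1)`; the handle is dropped here, presumably because the blocking send/recv pair that follows already paces the step. A standalone illustration of the handle API (plain `torch.distributed`, not code from this repo):

```python
# Standalone illustration of the Work handle returned by async collectives;
# a single-process gloo group is enough to show the API shape.
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29512")
dist.init_process_group("gloo", rank=0, world_size=1)

signal = torch.tensor(0, dtype=torch.int)
work = dist.broadcast(signal, src=0, async_op=True)  # returns a Work handle immediately
# ... rank 0 may overlap other communication or compute here ...
work.wait()  # only needed when completion must be observed explicitly
dist.destroy_process_group()
```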

python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py

+3 -19
```diff
@@ -354,11 +354,6 @@ def forward(
         return outputs

     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -501,7 +496,6 @@ def run_decode(
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
     num_hidden_layers = model.config.num_hidden_layers
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -559,13 +553,12 @@ def run_decode(
     with torch.inference_mode():
         while True:

-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 cache_position = torch.arange(
@@ -589,7 +582,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -598,11 +590,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -628,6 +617,7 @@ def __init__(self, model, max_seq_len, intra_pp=2, inter_pp=2, transpose_value_c
         self.input_queues = []
         self.output_queues = []
         self.decoder_processes = []
+        self.forward_signal = torch.tensor(0, dtype=torch.int)

         for rank in range(1, world_size):
             input_q = mp.Queue()
@@ -677,21 +667,17 @@ def forward(
         use_cache: bool = False,
         **kwargs,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)

-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
         return hidden_states, past_key_value

     def shutdown(self):
@@ -889,7 +875,6 @@ def minicpm_fused_model_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    t0 = time.perf_counter()
     output_attentions = (
         output_attentions if output_attentions is not None
         else self.config.output_attentions
@@ -978,7 +963,6 @@ def minicpm_fused_model_forward(
             for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
             if v is not None
         )
-    t1 = time.perf_counter()
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
         past_key_values=next_cache,
```
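On the worker side, baichuan_mp.py and minicpm_mp.py now spell out `async_op=False` on the control broadcast. `False` is already the default for `torch.distributed.broadcast`, so the change is behavior-preserving; it just makes explicit that workers block on the control signal each iteration, in deliberate contrast to the driver's `async_op=True`.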

python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py

+0 -16
```diff
@@ -412,11 +412,6 @@ def forward(
         return outputs

     def post_forward(self, past_key_value, new_keys, new_values):
-        key_value_states = []
-        for i in range(self.intra_stages):
-            for j in range(1, len(self.backend_decoders[i].torch_out)):
-                key_value_states.append(self.backend_decoders[i].torch_out[j])
-
         cache_kwargs = {
             "max_seq_len": self.max_seq_len,
             "transpose": self.transpose_value,
@@ -555,7 +550,6 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
-    deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -640,7 +634,6 @@ def run_decode(
             elif control.item() == -1:
                 past_key_values = input_queue.get()
             else:
-                t0 = time.perf_counter()
                 past_seen_tokens = past_key_values.get_seq_length()
                 attention_mask = torch.ones([1, past_seen_tokens + 1], dtype=torch.int64)
                 position_ids = torch.arange(
@@ -669,7 +662,6 @@ def run_decode(
                 )
                 padded_causal_mask[:, :, :, -1] = 0.0
                 dist.recv(hidden_states, src=rank - 1)
-                t1 = time.perf_counter()
                 layer_outputs = multi_decoder(
                     hidden_states,
                     attention_mask=padded_causal_mask,
@@ -678,11 +670,8 @@ def run_decode(
                     output_attentions=False,
                     use_cache=True,
                 )
-                t2 = time.perf_counter()
                 hidden_states = layer_outputs[0]
-                t3 = time.perf_counter()
                 dist.send(hidden_states, dst=(rank + 1) % world_size)
-                t4 = time.perf_counter()
                 past_key_values = layer_outputs[1]
                 new_keys = layer_outputs[2]
                 new_values = layer_outputs[3]
@@ -757,22 +746,17 @@ def forward(
         use_cache: bool = False,
         **kwargs,
     ):
-        t0 = time.perf_counter()
-
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)

-        t0 = time.perf_counter()
         dist.broadcast(self.forward_signal, src=0, async_op=True)
-        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t2 = time.perf_counter()
         return hidden_states, past_key_value

     def shutdown(self):
```
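The `def shutdown(self):` context line closing each `forward` hunk is the counterpart of the `control.item() == -2` branch in the worker loops. A hypothetical sketch of that path, inferred from the worker protocol rather than copied from the repo:

```python
# Hypothetical sketch (not the repo's actual implementation): the shutdown
# path implied by the `control.item() == -2` branch in run_decode.
def shutdown(self):
    control = torch.tensor(-2, dtype=torch.int)
    dist.broadcast(control, src=0)   # every worker's while-loop sees -2 and breaks
    for p in self.decoder_processes:
        p.join()                     # reap the decoder worker processes
```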
