
Commit 41b8064

Support minicpm-1B in level0 pipeline (#12297)
1 parent 46d8300 commit 41b8064

7 files changed: +435 −71 lines

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/README.md

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,7 @@ In this directory, you will find examples on how to directly run HuggingFace `tr
 | Llama2 | [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) |
 | Llama3 | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
 | Baichuan2 | [baichuan-inc/Baichuan2-7B-Chat](https://huggingface.co/baichuan-inc/Baichuan-7B-Chat) |
+| MiniCPM | [openbmb/MiniCPM-1B-sft-bf16](https://huggingface.co/openbmb/MiniCPM-1B-sft-bf16) |

 ## 0. Requirements
 To run these examples with IPEX-LLM on Intel NPUs, make sure to install the newest driver version of Intel NPU.
@@ -47,6 +48,9 @@ python llama3.py

 :: to run Baichuan2-7B-Chat
 python baichuan2.py
+
+:: to run MiniCPM-1B-sft-bf16
+python minicpm.py
 ```

 Arguments info:
python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/minicpm.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import torch
import time
import argparse
from ipex_llm.transformers.npu_model import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers.utils import logging
import os

logger = logging.get_logger(__name__)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Predict Tokens using `generate()` API for npu model"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="openbmb/MiniCPM-1B-sft-bf16",
        help="The huggingface repo id for the MiniCPM model to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
    parser.add_argument("--lowbit-path", type=str,
                        default="",
                        help="The path to the lowbit model folder, leave blank if you do not want to save. \
                        If path not exists, lowbit model will be saved there. \
                        Else, lowbit model will be loaded.",
                        )
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
    parser.add_argument("--max-context-len", type=int, default=1024)
    parser.add_argument("--max-prompt-len", type=int, default=512)
    parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path

    if not args.lowbit_path or not os.path.exists(args.lowbit_path):
        model = AutoModelForCausalLM.from_pretrained(model_path,
                                                     optimize_model=True,
                                                     pipeline=True,
                                                     max_context_len=args.max_context_len,
                                                     max_prompt_len=args.max_prompt_len,
                                                     torch_dtype=torch.float16,
                                                     attn_implementation="eager",
                                                     transpose_value_cache=not args.disable_transpose_value_cache,
                                                     trust_remote_code=True)
    else:
        model = AutoModelForCausalLM.load_low_bit(
            args.lowbit_path,
            attn_implementation="eager",
            torch_dtype=torch.float16,
            max_context_len=args.max_context_len,
            max_prompt_len=args.max_prompt_len,
            pipeline=True,
            transpose_value_cache=not args.disable_transpose_value_cache,
            trust_remote_code=True
        )

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    if args.lowbit_path and not os.path.exists(args.lowbit_path):
        model.save_low_bit(args.lowbit_path)

    print("-" * 80)
    print("done")
    with torch.inference_mode():
        print("finish to load")
        for i in range(5):
            prompt = "<用户>{}<AI>".format(args.prompt)
            _input_ids = tokenizer.encode(prompt, return_tensors="pt")
            print("input length:", len(_input_ids[0]))
            st = time.time()
            output = model.generate(
                _input_ids, max_new_tokens=args.n_predict, do_print=True
            )
            end = time.time()
            print(f"Inference time: {end-st} s")
            input_str = tokenizer.decode(_input_ids[0], skip_special_tokens=False)
            print("-" * 20, "Input", "-" * 20)
            print(input_str)
            output_str = tokenizer.decode(output[0], skip_special_tokens=False)
            print("-" * 20, "Output", "-" * 20)
            print(output_str)

    print("-" * 80)
    print("done")
    print("success shut down")

python/llm/example/NPU/HF-Transformers-AutoModels/LLM/minicpm.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@

     print("finish to load")
     for i in range(5):
-        _input_ids = tokenizer.encode("<用户>{}".format(args.prompt), return_tensors="pt")
+        _input_ids = tokenizer.encode("<用户>{}<AI>".format(args.prompt), return_tensors="pt")
         print("input length:", len(_input_ids[0]))
         st = time.time()
         output = model.generate(
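The one-line change above appends the `<AI>` marker so the prompt follows MiniCPM's `<用户>...<AI>` chat template, with the model generating the assistant turn after `<AI>`. A small standalone sketch of the template construction; `build_minicpm_prompt` is an illustrative helper name, not part of the commit:

# Hedged sketch: mirrors the "<用户>{}<AI>" template used in the examples above.
def build_minicpm_prompt(user_text: str) -> str:
    # the user turn is wrapped in <用户> ... <AI>; generation continues after <AI>
    return "<用户>{}<AI>".format(user_text)

if __name__ == "__main__":
    print(build_minicpm_prompt("What is AI?"))  # -> <用户>What is AI?<AI>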

python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py

Lines changed: 48 additions & 36 deletions
@@ -227,6 +227,46 @@ def convert_baichuan(
     convert_forward(model, module.BaichuanModel, baichuan_model_forward)


+def convert_minicpm(
+    model: torch.nn.Module,
+    max_output_len=1024,
+    max_prompt_len=1024,
+    decoder=False,
+    inter_pp=None,
+    intra_pp=None,
+    transpose_value_cache=True,
+):
+    from ipex_llm.transformers.npu_models.minicpm_mp import gen_minicpm_fused_model_forward
+    from ipex_llm.transformers.npu_models.minicpm_mp import DecodeRunner, PrefillRunner
+    modeling_module_name = model.__class__.__module__
+    module = importlib.import_module(modeling_module_name)
+
+    if decoder:
+        decode_runner = DecodeRunner(
+            model,
+            max_seq_len=max_output_len,
+            inter_pp=inter_pp,
+            intra_pp=intra_pp,
+            transpose_value_cache=transpose_value_cache,
+        )
+    else:
+        decode_runner = None
+    prefill_runner = PrefillRunner(
+        model,
+        max_output_len=max_output_len,
+        max_prompt_len=max_prompt_len,
+        transpose_value_cache=transpose_value_cache,
+    )
+    minicpm_model_forward = gen_minicpm_fused_model_forward(
+        prefill_runner=prefill_runner, decode_runner=decode_runner
+    )
+    convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
+    if model.config.num_hidden_layers == 40:
+        # for minicpm-2b
+        from ipex_llm.transformers.npu_models.minicpm_mp import minicpm_casullm_forward
+        convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)
+
+
 def optimize_llm(
     model: torch.nn.Module,
     max_context_len=1024,
@@ -291,41 +331,13 @@ def optimize_llm(
             intra_pp = 2
         if inter_pp is None:
             inter_pp = 2
-
-        from ipex_llm.transformers.npu_models.minicpm_mp import gen_minicpm_fused_model_forward
-        from ipex_llm.transformers.npu_models.minicpm_mp import DecodeRunner, PrefillRunner
-
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-
-        if model.config.num_hidden_layers == 52:
-            # for minicpm-1b
-            transpose_cache = transpose_value_cache
-        elif model.config.num_hidden_layers == 40:
-            # for minicpm-2b
-            transpose_cache = False
-
-        decode_runner = DecodeRunner(
-            model,
-            max_seq_len=max_context_len,
-            inter_pp=inter_pp,
-            intra_pp=intra_pp,
-            transpose_value_cache=transpose_cache,
-        )
-        prefill_runner = PrefillRunner(
-            model,
-            max_output_len=max_context_len,
-            max_prompt_len=max_prompt_len,
-            transpose_value_cache=transpose_cache,
-        )
-        minicpm_model_forward = gen_minicpm_fused_model_forward(
-            prefill_runner=prefill_runner, decode_runner=decode_runner
-        )
-        convert_forward(model, module.MiniCPMModel, minicpm_model_forward)
-        if model.config.num_hidden_layers == 40:
-            # for minicpm-2b
-            from ipex_llm.transformers.npu_models.minicpm_mp import minicpm_casullm_forward
-            convert_forward(model, module.MiniCPMForCausalLM, minicpm_casullm_forward)
+        convert_minicpm(model,
+                        max_output_len=max_context_len,
+                        max_prompt_len=max_prompt_len,
+                        inter_pp=inter_pp,
+                        intra_pp=intra_pp,
+                        decoder=True,
+                        transpose_value_cache=transpose_value_cache)
     elif model.config.model_type == "baichuan" and model.config.num_hidden_layers == 32:
         # for Baichuan2-7B
         if intra_pp is None:
@@ -339,7 +351,7 @@ def optimize_llm(
                          intra_pp=intra_pp,
                          decoder=True,
                          transpose_value_cache=transpose_value_cache)
-    if isinstance(model.lm_head, SlicedLMHead):
+    if hasattr(model, 'lm_head') and isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
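The new `convert_minicpm` helper relies on `convert_forward` to swap the stock `MiniCPMModel` forward for the NPU-fused one generated by `gen_minicpm_fused_model_forward`. A rough sketch of that patching pattern, assuming the replacement forward is rebound onto every matching module instance; `patch_forward` here is an illustrative stand-in, not the actual ipex_llm implementation:

import torch

def patch_forward(model: torch.nn.Module, target_cls, new_forward):
    # Replace the forward of every instance of target_cls found in the model.
    for module in model.modules():
        if isinstance(module, target_cls):
            # bind new_forward as an instance method so `self` is passed through
            module.forward = new_forward.__get__(module, target_cls)

if __name__ == "__main__":
    def fused_forward(self, x):
        return (x @ self.weight.t()) * 2  # toy "fused" replacement

    layer = torch.nn.Linear(4, 4, bias=False)
    model = torch.nn.Sequential(layer)
    patch_forward(model, torch.nn.Linear, fused_forward)
    print(model(torch.ones(1, 4)).shape)  # now routed through fused_forward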
python/llm/src/ipex_llm/transformers/npu_models/minicpm_mp.py

Lines changed: 39 additions & 34 deletions
@@ -54,7 +54,7 @@
 from torch.nn import CrossEntropyLoss


-class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
+class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
     def __init__(
         self,
         # batch_size: int,
@@ -118,31 +118,13 @@ def __init__(

         # Self Attention
         if mode == "decode":
-            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1))
+            attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
+                                                  dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len))
+            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
+                                                  dtype=np.int64)

-        position_ids = self.create_input_op((self.batch_size, self.seq_len))
-        past_keys = []
-        past_values = []
-        if mode == "decode":
-            for i in range(num_layers):
-                past_key = self.create_cache_op(
-                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                )
-                if transpose_value:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
-                    )
-                else:
-                    past_value = self.create_cache_op(
-                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
-                    )
-                past_keys.append(past_key)
-                past_values.append(past_value)
-        else:
-            past_keys = [None] * num_layers
-            past_values = [None] * num_layers
+        position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)

         if input_layernorm_weights is None:
             input_layernorm_weights = []
@@ -168,6 +150,27 @@ def __init__(
         input_layernorm_weights = [self.constant(w) for w in input_layernorm_weights]
         post_attn_layernorm_weights = [self.constant(w) for w in post_attn_layernorm_weights]

+        past_keys = []
+        past_values = []
+        if mode == "decode":
+            for i in range(num_layers):
+                past_key = self.create_cache_op(
+                    (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                )
+                if transpose_value:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.head_dim, self.max_seq_len)
+                    )
+                else:
+                    past_value = self.create_cache_op(
+                        (self.batch_size, self.num_key_value_heads, self.max_seq_len, self.head_dim)
+                    )
+                past_keys.append(past_key)
+                past_values.append(past_value)
+        else:
+            past_keys = [None] * num_layers
+            past_values = [None] * num_layers
+
         hidden_states = input

         curr_key_values = []
@@ -297,7 +300,7 @@ def __init__(
             start, end = self.layer_ranges[i]
             lm_0 = input_laynorm_weights[start:end]
             lm_1 = post_attn_layernorm_weights[start:end]
-            decoder = LowBitLlamaMultiDecoderlayer(
+            decoder = LowBitMinicpmMultiDecoderlayer(
                 [1, 1, num_heads * head_dim],
                 input_layernorm_weights=lm_0,
                 post_attn_layernorm_weights=lm_1,
@@ -334,15 +337,15 @@ def forward(

         inputs = (
             hidden_states.to(torch.float16),
-            attention_mask,
-            position_ids.to(torch.float16),
+            attention_mask.to(torch.int64),
+            position_ids.to(torch.int64),
         )

         for i in range(self.intra_stages):
             start, end = self.layer_ranges[i]
             self.backend_decoders[i].update_cache(past_key_value, self.layer_indexes[start:end])

-        hidden_states, new_keys, new_values = LowBitLlamaMultiDecoderlayer.run_decoders(
+        hidden_states, new_keys, new_values = LowBitMinicpmMultiDecoderlayer.run_decoders(
             inputs,
             decoders=self.backend_decoders)

@@ -403,7 +406,7 @@ def __init__(
             np_dtype = np.float16

         self.backend_cls_prefill = partial(
-            LowBitLlamaMultiDecoderlayer,
+            LowBitMinicpmMultiDecoderlayer,
             num_heads=num_heads,
             num_key_value_heads=num_key_value_heads,
             num_layers=1,
@@ -445,7 +448,9 @@ def forward(
         seq_len = hidden_states.shape[1]

         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16), attention_mask, position_ids.to(torch.float16))
+        inputs = (hidden_states.to(torch.float16),
+                  attention_mask.to(torch.int64),
+                  position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
             inputs, self.op_parameters, backend_cls, self.op_id, replica=2
@@ -578,9 +583,9 @@ def run_decode(

            pad_mask = (0, pad_len)
            padded_causal_mask = F.pad(
-                causal_mask.to(torch.float16), pad_mask, value=torch.finfo(torch.float16).min
+                causal_mask.to(torch.int64), pad_mask, value=torch.iinfo(torch.int64).min
            )
-            padded_causal_mask[:, :, :, -1] = 0.0
+            padded_causal_mask[:, :, :, -1] = 0
            dist.recv(hidden_states, src=rank - 1)
            layer_outputs = multi_decoder(
                hidden_states,
@@ -831,9 +836,9 @@ def forward(
         hidden_states = F.pad(hidden_states.to(torch.float16), (0, 0, 0, pad_len), value=0.0)
         position_ids = F.pad(position_ids, (0, pad_len), value=0)
         attention_mask = F.pad(
-            attention_mask.to(torch.float16),
+            attention_mask.to(torch.int64),
             (0, pad_len, 0, pad_len),
-            value=torch.finfo(torch.float16).min,
+            value=torch.iinfo(torch.int64).min,
         )

         args = (hidden_states, position_ids, attention_mask, past_key_value)