Commit 68a7707

Add moonlight GPU example (#12929)
* Add moonlight GPU example and update table
* Small fix
* Fix based on comments
* Small fix
1 parent 33da3a3 commit 68a7707

File tree: 6 files changed, +263 −0 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -337,6 +337,7 @@ Over 70 models have been optimized/verified on `ipex-llm`, including *LLaMA/LLaM
 | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
 | MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
 | Janus-Pro | | [link](python/llm/example/GPU/HuggingFace/Multimodal/janus-pro/) |
+| Moonlight | | [link](python/llm/example/GPU/HuggingFace/LLM/moonlight/) |
 | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
 | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
 | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |

README.zh-CN.md

Lines changed: 1 addition & 0 deletions
@@ -337,6 +337,7 @@ See the demo of running [*Text-Generation-WebUI*](https://ipex-llm.readthedocs.i
 | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
 | MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
 | Janus-Pro | | [link](python/llm/example/GPU/HuggingFace/Multimodal/janus-pro/) |
+| Moonlight | | [link](python/llm/example/GPU/HuggingFace/LLM/moonlight/) |
 | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
 | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
 | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
python/llm/example/GPU/HuggingFace/LLM/moonlight/README.md

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
# Moonlight

In this directory, you will find examples of how you can apply IPEX-LLM INT4 optimizations on the Moonlight model on [Intel GPUs](../../../README.md). For illustration purposes, we utilize [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) as a reference Moonlight model.

## 0. Requirements & Installation

To run these examples with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information.

### 0.1 Installation

```bash
conda create -n llm python=3.11
conda activate llm

# install IPEX-LLM with PyTorch 2.6 support
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu

pip install transformers==4.45.0
pip install accelerate==0.33.0
pip install "trl<0.12.0"

pip install tiktoken blobfile
```
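
As a quick sanity check after installation (a suggested step, not part of the original example), you can confirm that PyTorch detects the Intel GPU:

```bash
# expected to print True when the XPU device is visible to PyTorch
python -c "import torch; print(torch.xpu.is_available())"
```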

### 0.2 Runtime Configuration

- For Windows users:

  ```cmd
  set SYCL_CACHE_PERSISTENT=1
  :: optional
  set SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
  ```

- For Linux users:

  ```bash
  unset OCL_ICD_VENDOR
  export SYCL_CACHE_PERSISTENT=1
  # optional
  export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
  ```

> [!NOTE]
> The environment variable `SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS` determines the usage of immediate command lists for task submission to the GPU. Enabling this mode may improve performance, but it may also cause performance degradation in some cases. Please consider experimenting with and without this environment variable for the best performance. For more details, you can refer to [this article](https://www.intel.com/content/www/us/en/developer/articles/guide/level-zero-immediate-command-lists.html).
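
One way to run that experiment (a suggested workflow, not part of the original example) is to time the same run with and without the variable, e.g. using the `generate.py` example from Section 2:

```bash
# run once with immediate command lists enabled
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
python generate.py --converted-model-path <DOWNLOAD_DIR_PATH>-converted

# then unset it, run again, and compare the reported inference times
unset SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS
python generate.py --converted-model-path <DOWNLOAD_DIR_PATH>-converted
```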

## 1. Download & Convert Model

To run the Moonlight model with IPEX-LLM optimizations, we first need to download and convert it to make sure it can be successfully loaded by `transformers`.

### 1.1 Download Model

To download [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct) from Hugging Face, you could use [download.py](./download.py) through:

```bash
python download.py --repo-id moonshotai/Moonlight-16B-A3B-Instruct --commit-id 95583251e616c46a80715897a705cd38659afc27
```

By default, Moonlight-16B-A3B-Instruct will be downloaded to a folder named after the repo under the current directory. You could also define the download folder path by `--download-dir-path DOWNLOAD_DIR_PATH`.
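
For example (an illustrative invocation; `./models/moonlight` is a hypothetical folder of our choosing):

```bash
python download.py --repo-id moonshotai/Moonlight-16B-A3B-Instruct \
    --commit-id 95583251e616c46a80715897a705cd38659afc27 \
    --download-dir-path ./models/moonlight
```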

> [!TIP]
> Refer to [here](https://huggingface.co/docs/hub/en/models-downloading) for alternative methods to download models from Hugging Face.
>
> For [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct), please make sure to use its revision/commit id `95583251e616c46a80715897a705cd38659afc27`.

### 1.2 Convert Model

Next, convert the downloaded model by [convert.py](./convert.py):

```bash
python convert.py --model-path DOWNLOAD_DIR_PATH
```

The converted model will be saved at `<DOWNLOAD_DIR_PATH>-converted`.
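
For instance, assuming the model was downloaded to the default folder described above, the conversion step would look like:

```bash
# writes the converted copy next to the original folder,
# i.e. to ./Moonlight-16B-A3B-Instruct-converted
python convert.py --model-path ./Moonlight-16B-A3B-Instruct
```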

## 2. Example: Predict Tokens using `generate()` API

In the example [generate.py](./generate.py), we show a basic use case for a Moonlight model to predict the next N tokens using the `generate()` API, with IPEX-LLM INT4 optimizations on Intel GPUs.

### 2.1 Running the Example

```bash
python generate.py --converted-model-path <DOWNLOAD_DIR_PATH>-converted --prompt PROMPT --n-predict N_PREDICT
```

Arguments info:
- `--converted-model-path CONVERTED_MODEL_PATH`: argument defining the path to the model converted by [`convert.py`](./convert.py)
- `--prompt PROMPT`: argument defining the prompt to be inferred (the chat prompt format is applied automatically). It defaults to `'What is AI?'`.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.
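
For example, an illustrative run matching the sample output below (paths follow the hypothetical defaults used earlier):

```bash
python generate.py --converted-model-path ./Moonlight-16B-A3B-Instruct-converted \
    --prompt "Is 123 a prime?" \
    --n-predict 32
```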

### 2.2 Sample Outputs

#### [moonshotai/Moonlight-16B-A3B-Instruct](https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct)

```log
Inference time: xxxx s
-------------------- Prompt --------------------
Is 123 a prime?
-------------------- Output --------------------
<|im_system|>system<|im_middle|>You are a helpful assistant provided by Moonshot-AI.<|im_end|><|im_user|>user<|im_middle|>Is 123 a prime?<|im_end|><|im_assistant|>assistant<|im_middle|>No, 123 is not a prime number. A prime number is a number greater than 1 that has no positive divisors other than 1 and itself
```
python/llm/example/GPU/HuggingFace/LLM/moonlight/convert.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import argparse
from safetensors.torch import load_file, save_file

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convert Moonlight model to be successfully loaded by transformers')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to the downloaded Moonlight model')

    args = parser.parse_args()
    model_path = args.model_path
    converted_model_path = model_path + '-converted'

    # start from a clean output folder
    if os.path.exists(converted_model_path):
        shutil.rmtree(converted_model_path)

    os.makedirs(converted_model_path)

    for f in os.listdir(model_path):
        f_path = os.path.join(model_path, f)
        f_dst_path = os.path.join(converted_model_path, f)

        if f.endswith(".safetensors"):
            # re-save each safetensors shard with "pt" format metadata
            # so that transformers can load it
            save_file(load_file(f_path), f_dst_path, metadata={"format": "pt"})
        elif not f.startswith(".") and os.path.isfile(f_path):  # skip directories and hidden files
            shutil.copyfile(f_path, f_dst_path)

    print(f"Converted model successfully saved to {converted_model_path}")
python/llm/example/GPU/HuggingFace/LLM/moonlight/download.py

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
from huggingface_hub import snapshot_download

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Download Moonlight model')
    parser.add_argument('--repo-id', type=str, default='moonshotai/Moonlight-16B-A3B-Instruct',
                        help='Hugging Face repo id of the model to be downloaded')
    parser.add_argument('--commit-id', type=str, required=True,
                        help='Revision of the model to be downloaded')
    parser.add_argument('--download-dir-path', type=str,
                        help='Folder path where the model will be downloaded')

    args = parser.parse_args()

    repo_id = args.repo_id
    download_dir_path = args.download_dir_path
    if download_dir_path is None:
        # default to a folder named after the repo under the current directory
        download_dir_path = repo_id.rsplit("/", 1)[-1]

    snapshot_download(repo_id=repo_id,
                      revision=args.commit_id,
                      local_dir=download_dir_path)

    print(f'{repo_id} has been downloaded to {download_dir_path}')
python/llm/example/GPU/HuggingFace/LLM/moonlight/generate.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
import argparse
import torch
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Predict Tokens using `generate()` API for Moonlight model')
    parser.add_argument('--converted-model-path', type=str, required=True,
                        help='Path to the Moonlight model converted by convert.py')
    parser.add_argument('--prompt', type=str, default="What is AI?",
                        help='Prompt to infer')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()
    converted_model_path = args.converted_model_path

    # Load the model in 4 bit, which converts the relevant layers
    # in the model into INT4 format
    model = AutoModelForCausalLM.from_pretrained(converted_model_path,
                                                 load_in_4bit=True,
                                                 optimize_model=True,
                                                 trust_remote_code=True,
                                                 use_cache=True)
    model = model.to('xpu')

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(converted_model_path, trust_remote_code=True)

    # Generate predicted tokens
    with torch.inference_mode():
        # the prompt format here refers to
        # https://huggingface.co/moonshotai/Moonlight-16B-A3B-Instruct#inference-with-hugging-face-transformers
        messages = [
            {"role": "system", "content": "You are a helpful assistant provided by Moonshot-AI."},
            {"role": "user", "content": args.prompt}
        ]
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to('xpu')

        # an ipex_llm model needs a warmup run; only after it is the
        # measured inference time accurate
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)

        # start inference
        st = time.time()
        output = model.generate(input_ids,
                                max_new_tokens=args.n_predict)
        torch.xpu.synchronize()
        end = time.time()

        output_str = tokenizer.decode(output[0], skip_special_tokens=False)
        print(f'Inference time: {end-st} s')
        print('-'*20, 'Prompt', '-'*20)
        print(args.prompt)
        print('-'*20, 'Output', '-'*20)
        print(output_str)
