Skip to content

Commit c1463b0

Browse files
jeremyfowers, apsonawane, and amd-pworfolk
authored
Bring OGA under test and fix OGA server. Improve llm-prompt. (#272)
Co-authored-by: Akshay Sonawane <[email protected]> Co-authored-by: amd-pworfolk <[email protected]>
1 parent bc33e79 commit c1463b0

File tree

6 files changed

+220
-4
lines changed

6 files changed

+220
-4
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Lint and Test Lemonade for OGA on CPU

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

permissions:
  contents: read

jobs:
  make-oga-cpu-lemonade:
    env:
      LEMONADE_CI_MODE: "True"
    runs-on: windows-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Miniconda with 64-bit Python
        uses: conda-incubator/setup-miniconda@v2
        with:
          miniconda-version: "latest"
          activate-environment: lemon
          python-version: "3.10"
          run-post: "false"
      - name: Install dependencies
        shell: bash -el {0}
        run: |
          python -m pip install --upgrade pip
          conda install pylint
          python -m pip check
          pip install -e .[llm-oga-cpu]
      - name: Lint with Black
        uses: psf/black@stable
        with:
          options: "--check --verbose"
          src: "./src"
      - name: Lint with PyLint
        shell: bash -el {0}
        run: |
          pylint src/lemonade --rcfile .pylintrc --disable E0401
      - name: Test OGA+CPU server
        if: runner.os == 'Windows'
        timeout-minutes: 10
        uses: ./.github/actions/server-testing
        with:
          conda_env: -n lemon
          load_command: -i TinyPixel/small-llama2 oga-load --device cpu --dtype int4
          hf_token: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}" # Required by OGA model_builder in OGA 0.4.0 but not future versions
      - name: Run lemonade tests
        shell: bash -el {0}
        env:
          HF_TOKEN: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}" # Required by OGA model_builder in OGA 0.4.0 but not future versions
        run: |
          lemonade -i TinyPixel/small-llama2 oga-load --device cpu --dtype int4 llm-prompt -p "tell me a story" --max-new-tokens 5
          python test/oga_cpu_api.py
60+

setup.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@
6666
"fastapi",
6767
"uvicorn[standard]",
6868
],
69+
"llm-oga-cpu": [
70+
"onnxruntime-genai>=0.5.2",
71+
"torch>=2.0.0,<2.4",
72+
"transformers<4.45.0",
73+
"turnkeyml[llm]",
74+
],
6975
"llm-oga-igpu": [
7076
"onnxruntime-genai-directml==0.4.0",
7177
"torch>=2.0.0,<2.4",

src/lemonade/cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ class Keys:
2828
STD_DEV_SECONDS_TO_FIRST_TOKEN = "std_dev_seconds_to_first_token"
2929
CHECKPOINT = "checkpoint"
3030
DTYPE = "dtype"
31+
PROMPT = "prompt"
3132
PROMPT_TOKENS = "prompt_tokens"
33+
RESPONSE = "response"
34+
RESPONSE_TOKENS = "response_tokens"
3235
CACHE_DIR = "cache_dir"
3336
DEVICE = "device"
3437
OGA_MODELS_SUBFOLDER = "oga_models_subfolder"

src/lemonade/tools/chat.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from turnkeyml.state import State
1313
from turnkeyml.tools import Tool
1414
from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
15+
from lemonade.cache import Keys
1516

1617
DEFAULT_GENERATE_PARAMS = {
1718
"do_sample": True,
@@ -25,6 +26,10 @@
2526
END_OF_STREAM = "</s>"
2627

2728

29+
def sanitize_string(input_string: str) -> str:
    """
    Return a copy of input_string that can be encoded as UTF-8.

    Code points that cannot be encoded as UTF-8 (for example, unpaired
    surrogates) are dropped; every other character is returned unchanged.
    """
    return input_string.encode("utf-8", "ignore").decode("utf-8")
31+
32+
2833
class LLMPrompt(Tool):
2934
"""
3035
Send a prompt to an LLM instance and print the response to the screen.
@@ -43,7 +48,12 @@ class LLMPrompt(Tool):
4348
def __init__(self):
4449
super().__init__(monitor_message="Prompting LLM")
4550

46-
self.status_stats = ["response"]
51+
self.status_stats = [
52+
Keys.PROMPT_TOKENS,
53+
Keys.PROMPT,
54+
Keys.RESPONSE_TOKENS,
55+
Keys.RESPONSE,
56+
]
4757

4858
@staticmethod
4959
def parser(add_help: bool = True) -> argparse.ArgumentParser:
@@ -75,13 +85,31 @@ def run(
7585
tokenizer: TokenizerAdapter = state.tokenizer
7686

7787
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
88+
if isinstance(input_ids, list):
89+
# OGA models return a list of tokens
90+
len_tokens_in = len(input_ids)
91+
else:
92+
# HF models return a 2-D tensor
93+
len_tokens_in = input_ids.shape[1]
94+
7895
response = model.generate(
7996
input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
8097
)
81-
response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
98+
len_tokens_out = len(response[0]) - len_tokens_in
99+
input_ids = input_ids if isinstance(input_ids, list) else input_ids[0]
100+
i = 0
101+
while i < len_tokens_in and input_ids[i] == response[0][i]:
102+
i += 1
103+
response_text = tokenizer.decode(
104+
response[0][i:], skip_special_tokens=True
105+
).strip()
82106

83107
state.response = response_text
84-
state.save_stat("response", response_text)
108+
109+
state.save_stat(Keys.PROMPT_TOKENS, len_tokens_in)
110+
state.save_stat(Keys.PROMPT, prompt)
111+
state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
112+
state.save_stat(Keys.RESPONSE, sanitize_string(response_text))
85113

86114
return state
87115

src/lemonade/tools/ort_genai/oga.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import shutil
1616
from fnmatch import fnmatch
1717
from queue import Queue
18+
from packaging.version import Version
1819
from huggingface_hub import snapshot_download
1920
import onnxruntime_genai as og
2021
import onnxruntime_genai.models.builder as model_builder
@@ -120,11 +121,22 @@ def generate(
120121
):
121122
params = og.GeneratorParams(self.model)
122123

124+
# There is a breaking API change in OGA 0.6.0
125+
# Determine whether we should use the old or new APIs
126+
# This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
127+
use_oga_post_6_api = (
128+
Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
129+
)
130+
use_oga_pre_6_api = not use_oga_post_6_api
131+
123132
if pad_token_id:
124133
params.pad_token_id = pad_token_id
125134

126135
max_length = len(input_ids) + max_new_tokens
127136

137+
if use_oga_pre_6_api:
138+
params.input_ids = input_ids
139+
128140
if self.config and "search" in self.config:
129141
search_config = self.config["search"]
130142
params.set_search_options(
@@ -158,10 +170,13 @@ def generate(
158170
params.try_graph_capture_with_max_batch_size(1)
159171

160172
generator = og.Generator(self.model, params)
161-
generator.append_tokens(input_ids)
173+
if use_oga_post_6_api:
174+
generator.append_tokens(input_ids)
162175

163176
if streamer is None:
164177
prompt_start_time = time.perf_counter()
178+
if use_oga_pre_6_api:
179+
generator.compute_logits()
165180
generator.generate_next_token()
166181
prompt_end_time = time.perf_counter()
167182

@@ -172,6 +187,8 @@ def generate(
172187
token_gen_times = []
173188
while not generator.is_done():
174189
token_gen_start_time = time.perf_counter()
190+
if use_oga_pre_6_api:
191+
generator.compute_logits()
175192
generator.generate_next_token()
176193
token_gen_end_time = time.perf_counter()
177194

@@ -192,6 +209,8 @@ def generate(
192209
stop_early = False
193210

194211
while not generator.is_done() and not stop_early:
212+
if use_oga_pre_6_api:
213+
generator.compute_logits()
195214
generator.generate_next_token()
196215

197216
new_token = generator.get_next_tokens()[0]

test/oga_cpu_api.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import unittest
2+
import shutil
3+
import os
4+
import urllib3
5+
from turnkeyml.state import State
6+
import turnkeyml.common.test_helpers as common
7+
import turnkeyml.common.filesystem as fs
8+
from lemonade.tools.ort_genai.oga import OgaLoad
9+
from lemonade.tools.chat import LLMPrompt
10+
from lemonade.tools.mmlu import AccuracyMMLU
11+
from lemonade.tools.humaneval import AccuracyHumaneval
12+
13+
# Value of the LEMONADE_CI_MODE environment variable, or False when it is unset
ci_mode = os.getenv("LEMONADE_CI_MODE", False)

# Model and tool settings shared by the test cases below
checkpoint = "TinyPixel/small-llama2"
device = "cpu"
dtype = "int4"
force = False
prompt = "Alice and Bob"

# Probe the dataset host once at import time so that tests which need to
# download from it can be skipped, rather than fail, when it is unreachable.
url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
try:
    # preload_content=False avoids reading the response body up front
    resp = urllib3.request("GET", url, preload_content=False)
except urllib3.exceptions.HTTPError:
    eecs_berkeley_edu_cannot_be_reached = True
else:
    # Any 2xx or 3xx status counts as reachable
    eecs_berkeley_edu_cannot_be_reached = not 200 <= resp.status < 400
    resp.release_conn()
31+
32+
33+
class Testing(unittest.TestCase):
    """
    End-to-end checks of lemonade tools against an OGA model loaded on CPU.

    The tests read the module-level settings (checkpoint, device, dtype,
    force, prompt) and the cache_dir global, which is assigned in the
    __main__ guard before unittest.main() runs.
    """

    def setUp(self) -> None:
        # Start every test from an empty cache directory
        shutil.rmtree(cache_dir, ignore_errors=True)

    def test_001_ogaload(self):
        # Test the OgaLoad and LLMPrompt tools on a CPU model

        state = State(cache_dir=cache_dir, build_name="test")

        state = OgaLoad().run(
            state, input=checkpoint, device=device, dtype=dtype, force=force
        )
        state = LLMPrompt().run(state, prompt=prompt, max_new_tokens=5)

        # LLMPrompt stores only the newly generated text in state.response
        # (the prompt tokens are stripped before decoding), so a short
        # generation can legitimately be shorter than the prompt. Only
        # require that something was generated.
        self.assertGreater(len(state.response), 0, state.response)

    @unittest.skipIf(
        eecs_berkeley_edu_cannot_be_reached,
        "eecs.berkeley.edu cannot be reached for dataset download",
    )
    def test_002_accuracy_mmlu(self):
        # Test MMLU benchmarking with known model
        subject = ["management"]

        state = State(
            cache_dir=cache_dir,
            build_name="test",
        )

        state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
        state = AccuracyMMLU().run(state, ntrain=5, tests=subject)

        # The accuracy stat must have been saved and be non-negative
        stats = fs.Stats(state.cache_dir, state.build_name).stats
        self.assertGreaterEqual(stats[f"mmlu_{subject[0]}_accuracy"], 0)

    def test_003_accuracy_humaneval(self):
        """Test HumanEval benchmarking with known model"""

        state = State(
            cache_dir=cache_dir,
            build_name="test",
        )

        # Enable code evaluation for HumanEval
        os.environ["HF_ALLOW_CODE_EVAL"] = "1"

        state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
        state = AccuracyHumaneval().run(
            state,
            first_n_samples=1,  # Test only one problem for speed
            k_samples=1,  # Single attempt per problem
            timeout=30.0,
        )

        # Verify results
        stats = fs.Stats(state.cache_dir, state.build_name).stats
        self.assertIn(
            "humaneval_pass@1", stats, "HumanEval pass@1 metric not found"
        )
        self.assertIsInstance(
            stats["humaneval_pass@1"],
            (int, float),
            "HumanEval pass@1 metric should be numeric",
        )
94+
95+
96+
if __name__ == "__main__":
    # cache_dir is deliberately a module-level global: setUp() and every
    # test method read it, so it must be assigned before unittest.main().
    cache_dir, _ = common.create_test_dir(
        "lemonade_oga_cpu_api", base_dir=os.path.abspath(".")
    )
    unittest.main()

0 commit comments

Comments
 (0)