Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions .github/workflows/test_lemonade_oga_cpu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# CI workflow: installs Lemonade with the OGA CPU extras, lints the sources
# with Black and PyLint, then exercises the OGA+CPU server and API tests.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Lint and Test Lemonade for OGA on CPU

# Run on every push to main and on every pull request that targets main
on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

permissions:
  contents: read

jobs:
  make-oga-cpu-lemonade:
    env:
      LEMONADE_CI_MODE: "True"
    runs-on: windows-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Miniconda with 64-bit Python
        uses: conda-incubator/setup-miniconda@v2
        with:
          miniconda-version: "latest"
          activate-environment: lemon
          python-version: "3.10"
          run-post: "false"
      - name: Install dependencies
        shell: bash -el {0}
        run: |
          python -m pip install --upgrade pip
          conda install pylint
          python -m pip check
          pip install -e .[llm-oga-cpu]
      - name: Lint with Black
        uses: psf/black@stable
        with:
          options: "--check --verbose"
          src: "./src"
      - name: Lint with PyLint
        shell: bash -el {0}
        run: |
          pylint src/lemonade --rcfile .pylintrc --disable E0401
      - name: Test OGA+CPU server
        if: runner.os == 'Windows'
        timeout-minutes: 10
        uses: ./.github/actions/server-testing
        with:
          conda_env: -n lemon
          load_command: -i TinyPixel/small-llama2 oga-load --device cpu --dtype int4
          hf_token: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}" # Required by OGA model_builder in OGA 0.4.0 but not future versions
      - name: Run lemonade tests
        shell: bash -el {0}
        env:
          HF_TOKEN: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}" # Required by OGA model_builder in OGA 0.4.0 but not future versions
        run: |
          # Smoke-test the CLI first, then run the Python API test suite
          lemonade -i TinyPixel/small-llama2 oga-load --device cpu --dtype int4 llm-prompt -p "tell me a story" --max-new-tokens 5
          python test/oga_cpu_api.py

6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@
"fastapi",
"uvicorn[standard]",
],
"llm-oga-cpu": [
"onnxruntime-genai>=0.5.2",
"torch>=2.0.0,<2.4",
"transformers<4.45.0",
"turnkeyml[llm]",
],
"llm-oga-igpu": [
"onnxruntime-genai-directml==0.4.0",
"torch>=2.0.0,<2.4",
Expand Down
3 changes: 3 additions & 0 deletions src/lemonade/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ class Keys:
STD_DEV_SECONDS_TO_FIRST_TOKEN = "std_dev_seconds_to_first_token"
CHECKPOINT = "checkpoint"
DTYPE = "dtype"
PROMPT = "prompt"
PROMPT_TOKENS = "prompt_tokens"
RESPONSE = "response"
RESPONSE_TOKENS = "response_tokens"
CACHE_DIR = "cache_dir"
DEVICE = "device"
OGA_MODELS_SUBFOLDER = "oga_models_subfolder"
34 changes: 31 additions & 3 deletions src/lemonade/tools/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from turnkeyml.state import State
from turnkeyml.tools import Tool
from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
from lemonade.cache import Keys

DEFAULT_GENERATE_PARAMS = {
"do_sample": True,
Expand All @@ -25,6 +26,10 @@
END_OF_STREAM = "</s>"


def sanitize_string(input_string):
    """
    Return a copy of input_string that contains only characters which can
    be encoded as UTF-8.

    Characters that cannot be encoded (for example, lone surrogate code
    points) are silently dropped rather than raising UnicodeEncodeError.
    """
    utf8_bytes = input_string.encode("utf-8", errors="ignore")
    return utf8_bytes.decode("utf-8")


class LLMPrompt(Tool):
"""
Send a prompt to an LLM instance and print the response to the screen.
Expand All @@ -43,7 +48,12 @@ class LLMPrompt(Tool):
def __init__(self):
    """
    Initialize the tool with its monitor message and register the stat
    keys that this tool reports.
    """
    super().__init__(monitor_message="Prompting LLM")

    # These are the same keys that run() records with state.save_stat().
    # The earlier assignment of ["response"] was overwritten immediately
    # by this one, so it has been removed as a dead store.
    self.status_stats = [
        Keys.PROMPT_TOKENS,
        Keys.PROMPT,
        Keys.RESPONSE_TOKENS,
        Keys.RESPONSE,
    ]

@staticmethod
def parser(add_help: bool = True) -> argparse.ArgumentParser:
Expand Down Expand Up @@ -75,13 +85,31 @@ def run(
tokenizer: TokenizerAdapter = state.tokenizer

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
if isinstance(input_ids, list):
# OGA models return a list of tokens
len_tokens_in = len(input_ids)
else:
# HF models return a 2-D tensor
len_tokens_in = input_ids.shape[1]

response = model.generate(
input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
)
response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
len_tokens_out = len(response[0]) - len_tokens_in
input_ids = input_ids if isinstance(input_ids, list) else input_ids[0]
i = 0
while i < len_tokens_in and input_ids[i] == response[0][i]:
i += 1
response_text = tokenizer.decode(
response[0][i:], skip_special_tokens=True
).strip()

state.response = response_text
state.save_stat("response", response_text)

state.save_stat(Keys.PROMPT_TOKENS, len_tokens_in)
state.save_stat(Keys.PROMPT, prompt)
state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
state.save_stat(Keys.RESPONSE, sanitize_string(response_text))

return state

Expand Down
21 changes: 20 additions & 1 deletion src/lemonade/tools/ort_genai/oga.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import shutil
from fnmatch import fnmatch
from queue import Queue
from packaging.version import Version
from huggingface_hub import snapshot_download
import onnxruntime_genai as og
import onnxruntime_genai.models.builder as model_builder
Expand Down Expand Up @@ -120,11 +121,22 @@ def generate(
):
params = og.GeneratorParams(self.model)

# There is a breaking API change in OGA 0.6.0
# Determine whether we should use the old or new APIs
# This also supports 0.6.0.dev0, which evaluates to less than 0.6.0 in Version
use_oga_post_6_api = (
Version(og.__version__) >= Version("0.6.0") or "0.6.0" in og.__version__
)
use_oga_pre_6_api = not use_oga_post_6_api

if pad_token_id:
params.pad_token_id = pad_token_id

max_length = len(input_ids) + max_new_tokens

if use_oga_pre_6_api:
params.input_ids = input_ids

if self.config and "search" in self.config:
search_config = self.config["search"]
params.set_search_options(
Expand Down Expand Up @@ -158,10 +170,13 @@ def generate(
params.try_graph_capture_with_max_batch_size(1)

generator = og.Generator(self.model, params)
generator.append_tokens(input_ids)
if use_oga_post_6_api:
generator.append_tokens(input_ids)

if streamer is None:
prompt_start_time = time.perf_counter()
if use_oga_pre_6_api:
generator.compute_logits()
generator.generate_next_token()
prompt_end_time = time.perf_counter()

Expand All @@ -172,6 +187,8 @@ def generate(
token_gen_times = []
while not generator.is_done():
token_gen_start_time = time.perf_counter()
if use_oga_pre_6_api:
generator.compute_logits()
generator.generate_next_token()
token_gen_end_time = time.perf_counter()

Expand All @@ -192,6 +209,8 @@ def generate(
stop_early = False

while not generator.is_done() and not stop_early:
if use_oga_pre_6_api:
generator.compute_logits()
generator.generate_next_token()

new_token = generator.get_next_tokens()[0]
Expand Down
100 changes: 100 additions & 0 deletions test/oga_cpu_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import unittest
import shutil
import os
import urllib3
from turnkeyml.state import State
import turnkeyml.common.test_helpers as common
import turnkeyml.common.filesystem as fs
from lemonade.tools.ort_genai.oga import OgaLoad
from lemonade.tools.chat import LLMPrompt
from lemonade.tools.mmlu import AccuracyMMLU
from lemonade.tools.humaneval import AccuracyHumaneval

# CI-mode flag taken from the environment (False when the variable is unset)
ci_mode = os.getenv("LEMONADE_CI_MODE", False)

# Model and tool settings shared by every test in this module
checkpoint = "TinyPixel/small-llama2"
device = "cpu"
dtype = "int4"
force = False
prompt = "Alice and Bob"

# Probe the dataset host once, at import time, so that the MMLU test can be
# skipped (rather than fail) when the host cannot be reached
try:
    url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
    # Only the status code is needed, so do not preload the response body
    resp = urllib3.request("GET", url, preload_content=False)
    eecs_berkeley_edu_cannot_be_reached = not 200 <= resp.status < 400
    resp.release_conn()
except urllib3.exceptions.HTTPError:
    eecs_berkeley_edu_cannot_be_reached = True


class Testing(unittest.TestCase):
    """
    End-to-end tests of the OGA tools using the module-level checkpoint,
    device, and dtype settings.

    Every test reads the module-level `cache_dir`, which is only assigned
    in this file's `if __name__ == "__main__"` block.
    """

    def setUp(self) -> None:
        # Start every test from an empty cache directory
        shutil.rmtree(cache_dir, ignore_errors=True)

    def test_001_ogaload(self):
        # Test the OgaLoad and LLMPrompt tools on a CPU model

        state = State(cache_dir=cache_dir, build_name="test")

        state = OgaLoad().run(
            state, input=checkpoint, device=device, dtype=dtype, force=force
        )
        state = LLMPrompt().run(state, prompt=prompt, max_new_tokens=5)

        # Use unittest assertions (here and below) so the checks are not
        # stripped when Python runs with -O
        self.assertGreater(len(state.response), len(prompt), state.response)

    @unittest.skipIf(
        eecs_berkeley_edu_cannot_be_reached,
        "eecs.berkeley.edu cannot be reached for dataset download",
    )
    def test_002_accuracy_mmlu(self):
        # Test MMLU benchmarking with known model
        subject = ["management"]

        state = State(
            cache_dir=cache_dir,
            build_name="test",
        )

        state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
        state = AccuracyMMLU().run(state, ntrain=5, tests=subject)

        stats = fs.Stats(state.cache_dir, state.build_name).stats
        self.assertGreaterEqual(stats[f"mmlu_{subject[0]}_accuracy"], 0)

    def test_003_accuracy_humaneval(self):
        """Test HumanEval benchmarking with known model"""

        state = State(
            cache_dir=cache_dir,
            build_name="test",
        )

        # Enable code evaluation for HumanEval
        os.environ["HF_ALLOW_CODE_EVAL"] = "1"

        state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
        state = AccuracyHumaneval().run(
            state,
            first_n_samples=1,  # Test only one problem for speed
            k_samples=1,  # Single attempt per problem
            timeout=30.0,
        )

        # Verify results
        stats = fs.Stats(state.cache_dir, state.build_name).stats
        self.assertIn(
            "humaneval_pass@1", stats, "HumanEval pass@1 metric not found"
        )
        self.assertIsInstance(
            stats["humaneval_pass@1"],
            (int, float),
            "HumanEval pass@1 metric should be numeric",
        )


if __name__ == "__main__":
    # The tests read the module-level `cache_dir`, so it has to be created
    # before unittest.main() starts running them
    cache_dir, _unused = common.create_test_dir(
        "lemonade_oga_cpu_api",
        base_dir=os.path.abspath("."),
    )
    unittest.main()
Loading