Bring OGA under test and fix OGA server. Improve llm-prompt. #272

Merged · 7 commits · Jan 21, 2025 · Changes shown from 3 commits
60 changes: 60 additions & 0 deletions .github/workflows/test_lemonade_oga_cpu.yml
@@ -0,0 +1,60 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Lint and Test Lemonade for OGA on CPU

on:
push:
branches: ["main"]
pull_request:
branches: ["main"]

permissions:
contents: read

jobs:
make-oga-cpu-lemonade:
env:
LEMONADE_CI_MODE: "True"
runs-on: windows-latest
steps:
- uses: actions/checkout@v3
- name: Set up Miniconda with 64-bit Python
uses: conda-incubator/setup-miniconda@v2
with:
miniconda-version: "latest"
activate-environment: lemon
python-version: "3.10"
run-post: "false"
- name: Install dependencies
shell: bash -el {0}
run: |
python -m pip install --upgrade pip
conda install pylint
python -m pip check
pip install -e .[llm-oga-cpu]
- name: Lint with Black
uses: psf/black@stable
with:
options: "--check --verbose"
src: "./src"
- name: Lint with PyLint
shell: bash -el {0}
run: |
pylint src/lemonade --rcfile .pylintrc --disable E0401
- name: Test OGA+CPU server
if: runner.os == 'Windows'
timeout-minutes: 10
uses: ./.github/actions/server-testing
with:
conda_env: -n lemon
load_command: -i Qwen/Qwen2.5-0.5B-Instruct oga-load --device cpu --dtype int4
hf_token: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}" # Required by OGA model_builder in OGA 0.4.0 but not future versions
- name: Run lemonade tests
shell: bash -el {0}
env:
HF_TOKEN: "${{ secrets.HUGGINGFACE_ACCESS_TOKEN }}" # Required by OGA model_builder in OGA 0.4.0 but not future versions
run: |
lemonade -i Qwen/Qwen2.5-0.5B-Instruct oga-load --device cpu --dtype int4 llm-prompt -p "hi what is your name" --max-new-tokens 10
python test/oga_cpu_api.py

6 changes: 6 additions & 0 deletions setup.py
@@ -66,6 +66,12 @@
"fastapi",
"uvicorn[standard]",
],
"llm-oga-cpu": [
"onnxruntime-genai==0.5.2",
"torch>=2.0.0,<2.4",
"transformers<4.45.0",
"turnkeyml[llm]",
],
"llm-oga-igpu": [
"onnxruntime-genai-directml==0.4.0",
"torch>=2.0.0,<2.4",
3 changes: 3 additions & 0 deletions src/lemonade/cache.py
@@ -28,7 +28,10 @@ class Keys:
STD_DEV_SECONDS_TO_FIRST_TOKEN = "std_dev_seconds_to_first_token"
CHECKPOINT = "checkpoint"
DTYPE = "dtype"
PROMPT = "prompt"
PROMPT_TOKENS = "prompt_tokens"
RESPONSE = "response"
RESPONSE_TOKENS = "response_tokens"
CACHE_DIR = "cache_dir"
DEVICE = "device"
OGA_MODELS_SUBFOLDER = "oga_models_subfolder"
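
These keys are the stat names that LLMPrompt now records for each run. A minimal sketch of reading them back from a finished build, using the turnkeyml stats helper that also appears in the test below; cache_dir and build_name are illustrative placeholders, not values from this PR:

# Illustrative: read the new prompt/response stats back after an LLMPrompt run.
import turnkeyml.common.filesystem as fs
from lemonade.cache import Keys

cache_dir = "/path/to/lemonade/cache"  # placeholder for your cache directory
build_name = "test"                    # placeholder for the build name used in the run

stats = fs.Stats(cache_dir, build_name).stats
print(stats[Keys.PROMPT], stats[Keys.PROMPT_TOKENS])
print(stats[Keys.RESPONSE], stats[Keys.RESPONSE_TOKENS])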
30 changes: 27 additions & 3 deletions src/lemonade/tools/chat.py
@@ -12,6 +12,7 @@
from turnkeyml.state import State
from turnkeyml.tools import Tool
from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter
from lemonade.cache import Keys

DEFAULT_GENERATE_PARAMS = {
"do_sample": True,
@@ -43,7 +44,12 @@ class LLMPrompt(Tool):
def __init__(self):
super().__init__(monitor_message="Prompting LLM")

self.status_stats = ["response"]
self.status_stats = [
Keys.PROMPT_TOKENS,
Keys.PROMPT,
Keys.RESPONSE_TOKENS,
Keys.RESPONSE,
]

@staticmethod
def parser(add_help: bool = True) -> argparse.ArgumentParser:
Expand Down Expand Up @@ -75,13 +81,31 @@ def run(
tokenizer: TokenizerAdapter = state.tokenizer

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
if isinstance(input_ids, list):
# OGA models return a list of tokens
len_tokens_in = len(input_ids)
else:
# HF models return a 2-D tensor
len_tokens_in = input_ids.shape[1]

response = model.generate(
input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
)
response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()
len_tokens_out = len(response[0]) - len_tokens_in
input_ids = input_ids if isinstance(input_ids, list) else input_ids[0]
i = 0
while i < len_tokens_in and input_ids[i] == response[0][i]:
i += 1
response_text = tokenizer.decode(
response[0][i:], skip_special_tokens=True
).strip()

state.response = response_text
state.save_stat("response", response_text)

state.save_stat(Keys.PROMPT_TOKENS, len_tokens_in)
state.save_stat(Keys.PROMPT, prompt)
state.save_stat(Keys.RESPONSE_TOKENS, len_tokens_out)
state.save_stat(Keys.RESPONSE, response_text)

return state

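
A minimal sketch of the token accounting that LLMPrompt now performs, consolidating the hunks above; it assumes tokenizer and model were placed on the state by a prior load tool, handles both the OGA adapter (plain token list) and Hugging Face (2-D tensor) cases, and trims any echoed prompt tokens before decoding:

# Sketch of the new prompt/response accounting in LLMPrompt.run()
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
if isinstance(input_ids, list):
    len_tokens_in = len(input_ids)        # OGA models return a list of tokens
else:
    len_tokens_in = input_ids.shape[1]    # HF models return a 2-D tensor

response = model.generate(input_ids, max_new_tokens=max_new_tokens)
len_tokens_out = len(response[0]) - len_tokens_in

# Some models echo the prompt at the start of the response; skip those tokens
prompt_ids = input_ids if isinstance(input_ids, list) else input_ids[0]
i = 0
while i < len_tokens_in and prompt_ids[i] == response[0][i]:
    i += 1
response_text = tokenizer.decode(response[0][i:], skip_special_tokens=True).strip()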
5 changes: 4 additions & 1 deletion src/lemonade/tools/ort_genai/oga.py
@@ -125,6 +125,7 @@ def generate(

max_length = len(input_ids) + max_new_tokens

params.input_ids = input_ids
if self.config and "search" in self.config:
search_config = self.config["search"]
params.set_search_options(
@@ -158,10 +159,10 @@ def generate(
params.try_graph_capture_with_max_batch_size(1)

generator = og.Generator(self.model, params)
generator.append_tokens(input_ids)

if streamer is None:
prompt_start_time = time.perf_counter()
generator.compute_logits()
generator.generate_next_token()
prompt_end_time = time.perf_counter()

@@ -172,6 +173,7 @@
token_gen_times = []
while not generator.is_done():
token_gen_start_time = time.perf_counter()
generator.compute_logits()
generator.generate_next_token()
token_gen_end_time = time.perf_counter()

@@ -192,6 +194,7 @@
stop_early = False

while not generator.is_done() and not stop_early:
generator.compute_logits()
generator.generate_next_token()

new_token = generator.get_next_tokens()[0]
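
The hunks above track the onnxruntime-genai 0.5.x API: prompt tokens are fed with generator.append_tokens() instead of params.input_ids, and generate_next_token() no longer needs a preceding compute_logits() call. A minimal standalone sketch of that flow, with an illustrative model directory and prompt (not the exact code in this file):

# Minimal onnxruntime-genai 0.5.x generation loop (illustrative)
import onnxruntime_genai as og

model_dir = "path/to/oga/model"             # illustrative path to an OGA model folder
model = og.Model(model_dir)
tokenizer = og.Tokenizer(model)
input_ids = tokenizer.encode("Alice and Bob")

params = og.GeneratorParams(model)
params.set_search_options(max_length=len(input_ids) + 10)

generator = og.Generator(model, params)
generator.append_tokens(input_ids)          # 0.5.x replacement for params.input_ids = input_ids

new_tokens = []
while not generator.is_done():
    generator.generate_next_token()         # no separate compute_logits() in 0.5.x
    new_tokens.append(generator.get_next_tokens()[0])

print(tokenizer.decode(new_tokens))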
100 changes: 100 additions & 0 deletions test/oga_cpu_api.py
@@ -0,0 +1,100 @@
import unittest
import shutil
import os
import urllib3
from turnkeyml.state import State
import turnkeyml.common.test_helpers as common
import turnkeyml.common.filesystem as fs
from lemonade.tools.ort_genai.oga import OgaLoad
from lemonade.tools.chat import LLMPrompt
from lemonade.tools.mmlu import AccuracyMMLU
from lemonade.tools.humaneval import AccuracyHumaneval

ci_mode = os.getenv("LEMONADE_CI_MODE", False)

checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
device = "cpu"
dtype = "int4"
force = False
prompt = "Alice and Bob"

try:
url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
resp = urllib3.request("GET", url, preload_content=False)
if 200 <= resp.status < 400:
eecs_berkeley_edu_cannot_be_reached = False
else:
eecs_berkeley_edu_cannot_be_reached = True
resp.release_conn()
except urllib3.exceptions.HTTPError:
eecs_berkeley_edu_cannot_be_reached = True


class Testing(unittest.TestCase):

def setUp(self) -> None:
shutil.rmtree(cache_dir, ignore_errors=True)

def test_001_ogaload(self):
# Test the OgaLoad and LLMPrompt tools on a CPU model

state = State(cache_dir=cache_dir, build_name="test")

state = OgaLoad().run(
state, input=checkpoint, device=device, dtype=dtype, force=force
)
state = LLMPrompt().run(state, prompt=prompt, max_new_tokens=10)

assert len(state.response) > len(prompt), state.response

@unittest.skipIf(
eecs_berkeley_edu_cannot_be_reached,
"eecs.berkeley.edu cannot be reached for dataset download",
)
def test_002_accuracy_mmlu(self):
# Test MMLU benchmarking with known model
subject = ["management"]

state = State(
cache_dir=cache_dir,
build_name="test",
)

state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
state = AccuracyMMLU().run(state, ntrain=5, tests=subject)

stats = fs.Stats(state.cache_dir, state.build_name).stats
assert stats[f"mmlu_{subject[0]}_accuracy"] > 0

def test_003_accuracy_humaneval(self):
"""Test HumanEval benchmarking with known model"""

state = State(
cache_dir=cache_dir,
build_name="test",
)

# Enable code evaluation for HumanEval
os.environ["HF_ALLOW_CODE_EVAL"] = "1"

state = OgaLoad().run(state, input=checkpoint, device=device, dtype=dtype)
state = AccuracyHumaneval().run(
state,
first_n_samples=1, # Test only one problem for speed
k_samples=1, # Single attempt per problem
timeout=30.0,
)

# Verify results
stats = fs.Stats(state.cache_dir, state.build_name).stats
assert "humaneval_pass@1" in stats, "HumanEval pass@1 metric not found"
assert isinstance(
stats["humaneval_pass@1"], (int, float)
), "HumanEval pass@1 metric should be numeric"


if __name__ == "__main__":
cache_dir, _ = common.create_test_dir(
"lemonade_oga_cpu_api", base_dir=os.path.abspath(".")
)
unittest.main()
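
For reference, a minimal interactive sketch of the flow exercised by test_001_ogaload, assuming the package was installed with the llm-oga-cpu extra; the cache path is an illustrative scratch directory, and HF_TOKEN may need to be set in the environment for the model download (see the workflow above):

import os
from turnkeyml.state import State
from lemonade.tools.ort_genai.oga import OgaLoad
from lemonade.tools.chat import LLMPrompt

cache_dir = os.path.abspath("oga_cpu_scratch")  # illustrative scratch directory

state = State(cache_dir=cache_dir, build_name="demo")
state = OgaLoad().run(state, input="Qwen/Qwen2.5-0.5B-Instruct", device="cpu", dtype="int4")
state = LLMPrompt().run(state, prompt="Alice and Bob", max_new_tokens=10)
print(state.response)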