Skip to content

Commit ac1a740

Browse files
committed
Better compatibility for OGA
1 parent b43c0e9 commit ac1a740

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
"uvicorn[standard]",
6868
],
6969
"llm-oga-cpu": [
70-
"onnxruntime-genai==0.5.2",
70+
"onnxruntime-genai>=0.5.2",
7171
"torch>=2.0.0,<2.4",
7272
"transformers<4.45.0",
7373
"turnkeyml[llm]",

Diff for: src/lemonade/tools/ort_genai/oga.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import shutil
1616
from fnmatch import fnmatch
1717
from queue import Queue
18+
from packaging.version import Version
1819
from huggingface_hub import snapshot_download
1920
import onnxruntime_genai as og
2021
import onnxruntime_genai.models.builder as model_builder
@@ -120,12 +121,19 @@ def generate(
120121
):
121122
params = og.GeneratorParams(self.model)
122123

124+
# There is a breaking API change in OGA 0.6.0
125+
# Determine whether we should use the old or new APIs
126+
use_oga_pre_6_api = Version(og.__version__) < Version("0.6.0")
127+
use_oga_post_6_api = not use_oga_pre_6_api
128+
123129
if pad_token_id:
124130
params.pad_token_id = pad_token_id
125131

126132
max_length = len(input_ids) + max_new_tokens
127133

128-
params.input_ids = input_ids
134+
if use_oga_pre_6_api:
135+
params.input_ids = input_ids
136+
129137
if self.config and "search" in self.config:
130138
search_config = self.config["search"]
131139
params.set_search_options(
@@ -159,10 +167,13 @@ def generate(
159167
params.try_graph_capture_with_max_batch_size(1)
160168

161169
generator = og.Generator(self.model, params)
170+
if use_oga_post_6_api:
171+
generator.append_tokens(input_ids)
162172

163173
if streamer is None:
164174
prompt_start_time = time.perf_counter()
165-
generator.compute_logits()
175+
if use_oga_pre_6_api:
176+
generator.compute_logits()
166177
generator.generate_next_token()
167178
prompt_end_time = time.perf_counter()
168179

@@ -173,7 +184,8 @@ def generate(
173184
token_gen_times = []
174185
while not generator.is_done():
175186
token_gen_start_time = time.perf_counter()
176-
generator.compute_logits()
187+
if use_oga_pre_6_api:
188+
generator.compute_logits()
177189
generator.generate_next_token()
178190
token_gen_end_time = time.perf_counter()
179191

@@ -194,7 +206,8 @@ def generate(
194206
stop_early = False
195207

196208
while not generator.is_done() and not stop_early:
197-
generator.compute_logits()
209+
if use_oga_pre_6_api:
210+
generator.compute_logits()
198211
generator.generate_next_token()
199212

200213
new_token = generator.get_next_tokens()[0]

0 commit comments

Comments
 (0)