15
15
import shutil
16
16
from fnmatch import fnmatch
17
17
from queue import Queue
18
+ from packaging .version import Version
18
19
from huggingface_hub import snapshot_download
19
20
import onnxruntime_genai as og
20
21
import onnxruntime_genai .models .builder as model_builder
@@ -120,12 +121,19 @@ def generate(
120
121
):
121
122
params = og .GeneratorParams (self .model )
122
123
124
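+ # OGA 0.6.0 introduced a breaking API change: the prompt is passed with
+ # generator.append_tokens() instead of params.input_ids, and
+ # generator.compute_logits() is no longer called before generate_next_token()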
125
+ # Use the installed OGA version to decide whether to call the old or the new API
126
+ use_oga_pre_6_api = Version (og .__version__ ) < Version ("0.6.0" )
127
+ use_oga_post_6_api = not use_oga_pre_6_api
128
+
123
129
if pad_token_id :
124
130
params .pad_token_id = pad_token_id
125
131
126
132
max_length = len (input_ids ) + max_new_tokens
127
133
128
- params .input_ids = input_ids
134
+ if use_oga_pre_6_api :
135
+ params .input_ids = input_ids
136
+
129
137
if self .config and "search" in self .config :
130
138
search_config = self .config ["search" ]
131
139
params .set_search_options (
@@ -159,10 +167,13 @@ def generate(
159
167
params .try_graph_capture_with_max_batch_size (1 )
160
168
161
169
generator = og .Generator (self .model , params )
170
+ if use_oga_post_6_api :
171
+ generator .append_tokens (input_ids )
162
172
163
173
if streamer is None :
164
174
prompt_start_time = time .perf_counter ()
165
- generator .compute_logits ()
175
+ if use_oga_pre_6_api :
176
+ generator .compute_logits ()
166
177
generator .generate_next_token ()
167
178
prompt_end_time = time .perf_counter ()
168
179
@@ -173,7 +184,8 @@ def generate(
173
184
token_gen_times = []
174
185
while not generator .is_done ():
175
186
token_gen_start_time = time .perf_counter ()
176
- generator .compute_logits ()
187
+ if use_oga_pre_6_api :
188
+ generator .compute_logits ()
177
189
generator .generate_next_token ()
178
190
token_gen_end_time = time .perf_counter ()
179
191
@@ -194,7 +206,8 @@ def generate(
194
206
stop_early = False
195
207
196
208
while not generator .is_done () and not stop_early :
197
- generator .compute_logits ()
209
+ if use_oga_pre_6_api :
210
+ generator .compute_logits ()
198
211
generator .generate_next_token ()
199
212
200
213
new_token = generator .get_next_tokens ()[0 ]
0 commit comments