Skip to content

Commit c107c2a

Browse files
Minor changes fixing onnxruntime_genai issue and input_path (#267)
Signed-off-by: Akshay Sonawane <[email protected]>
Co-authored-by: Ramakrishnan Sivakumar <[email protected]>
1 parent 0a58235 commit c107c2a

File tree

1 file changed

+10 -5 lines changed (10 additions, 5 deletions)
  • src/lemonade/tools/ort_genai

1 file changed

+10 -5 lines changed (10 additions, 5 deletions)

src/lemonade/tools/ort_genai/oga.py

+10 -5
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,6 @@ def generate(
125125

126126
max_length = len(input_ids) + max_new_tokens
127127

128-
params.input_ids = input_ids
129128
if self.config and "search" in self.config:
130129
search_config = self.config["search"]
131130
params.set_search_options(
@@ -159,10 +158,10 @@ def generate(
159158
params.try_graph_capture_with_max_batch_size(1)
160159

161160
generator = og.Generator(self.model, params)
161+
generator.append_tokens(input_ids)
162162

163163
if streamer is None:
164164
prompt_start_time = time.perf_counter()
165-
generator.compute_logits()
166165
generator.generate_next_token()
167166
prompt_end_time = time.perf_counter()
168167

@@ -173,7 +172,6 @@ def generate(
173172
token_gen_times = []
174173
while not generator.is_done():
175174
token_gen_start_time = time.perf_counter()
176-
generator.compute_logits()
177175
generator.generate_next_token()
178176
token_gen_end_time = time.perf_counter()
179177

@@ -194,7 +192,6 @@ def generate(
194192
stop_early = False
195193

196194
while not generator.is_done() and not stop_early:
197-
generator.compute_logits()
198195
generator.generate_next_token()
199196

200197
new_token = generator.get_next_tokens()[0]
@@ -253,6 +250,13 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
253250
add_help=add_help,
254251
)
255252

253+
parser.add_argument(
254+
"-ip",
255+
"--input_path",
256+
default="",
257+
help="the local huggingface model in your disk",
258+
)
259+
256260
parser.add_argument(
257261
"-d",
258262
"--device",
@@ -304,6 +308,7 @@ def run(
304308
self,
305309
state: State,
306310
input: str,
311+
input_path: str = "",
307312
device: str = "igpu",
308313
dtype: str = "int4",
309314
int4_block_size: int = None,
@@ -449,7 +454,7 @@ def run(
449454
try:
450455
model_builder.create_model(
451456
checkpoint, # model_name
452-
"", # input_path
457+
input_path, # input_path
453458
full_model_path, # output_path
454459
dtype, # precision
455460
execution_providers[device], # execution_provider

0 commit comments

Comments
 (0)