Skip to content

Commit 6adaa89

Browse files
committed
Fix the gibberish output from llm-prompt
Signed-off-by: Akshay Sonawane <[email protected]>
1 parent 4f3be13 commit 6adaa89

File tree

1 file changed

+52
-1
lines changed
  • src/lemonade/tools/ort_genai/oga.py

1 file changed

+52
-1
lines changed

Diff for: src/lemonade/tools/ort_genai/oga.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import time
1414
import json
1515
import shutil
16+
import numpy as np
1617
from fnmatch import fnmatch
1718
from queue import Queue
1819
from packaging.version import Version
@@ -102,8 +103,10 @@ class OrtGenaiModel(ModelAdapter):
102103
def __init__(self, input_folder):
103104
super().__init__()
104105
self.model = og.Model(input_folder)
106+
self.model_path = input_folder
105107
self.type = "ort-genai"
106108
self.config = self.load_config(input_folder)
109+
self.tokenizer = og.Tokenizer(self.model)
107110

108111
def load_config(self, input_folder):
109112
config_path = os.path.join(input_folder, "genai_config.json")
@@ -124,7 +127,43 @@ def generate(
124127
streamer: OrtGenaiStreamer = None,
125128
pad_token_id=None,
126129
stopping_criteria=None,
130+
chat_template="",
127131
):
132+
133+
# Get model type
134+
model_type = None
135+
if hasattr(self.model, "type"):
136+
model_type = self.model.type
137+
else:
138+
import json, os
139+
140+
with open(os.path.join(self.model_path, "genai_config.json"), "r") as f:
141+
genai_config = json.load(f)
142+
model_type = genai_config["model"]["type"]
143+
144+
# Set chat template
145+
if chat_template:
146+
if chat_template.count("{") != 1 or chat_template.count("}") != 1:
147+
raise ValueError(
148+
"Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'"
149+
)
150+
else:
151+
if model_type.startswith("phi2") or model_type.startswith("phi3"):
152+
chat_template = "<|user|>\n{input} <|end|>\n<|assistant|>"
153+
elif model_type.startswith("phi4"):
154+
chat_template = "<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>"
155+
elif model_type.startswith("llama3"):
156+
chat_template = "<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
157+
elif model_type.startswith("llama2"):
158+
chat_template = "<s>{input}"
159+
elif model_type.startswith("qwen2"):
160+
chat_template = (
161+
"<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n"
162+
)
163+
else:
164+
raise ValueError(
165+
f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template"
166+
)
128167
params = og.GeneratorParams(self.model)
129168

130169
# There is a breaking API change in OGA 0.6.0
@@ -144,6 +183,13 @@ def generate(
144183
if use_oga_pre_6_api:
145184
params.input_ids = input_ids
146185

186+
if isinstance(input_ids, list):
187+
input_ids_np = np.array(input_ids, dtype=np.int32)
188+
else:
189+
input_ids_np = input_ids.cpu().numpy().astype(np.int32)
190+
191+
decoded_prompt = self.tokenizer.decode(input_ids_np)
192+
147193
if self.config and "search" in self.config:
148194
search_config = self.config["search"]
149195
params.set_search_options(
@@ -177,8 +223,13 @@ def generate(
177223
params.try_graph_capture_with_max_batch_size(1)
178224

179225
generator = og.Generator(self.model, params)
226+
prompt = decoded_prompt
227+
prompt = f"{chat_template.format(input=decoded_prompt)}"
228+
229+
input_tokens = self.tokenizer.encode(prompt)
230+
180231
if use_oga_post_6_api:
181-
generator.append_tokens(input_ids)
232+
generator.append_tokens(input_tokens)
182233

183234
if streamer is None:
184235
prompt_start_time = time.perf_counter()

0 commit comments

Comments
 (0)