 import time
 import json
 import shutil
+import numpy as np
 from fnmatch import fnmatch
 from queue import Queue
 from packaging.version import Version
@@ -102,8 +103,10 @@ class OrtGenaiModel(ModelAdapter):
     def __init__(self, input_folder):
         super().__init__()
         self.model = og.Model(input_folder)
+        self.model_path = input_folder
         self.type = "ort-genai"
         self.config = self.load_config(input_folder)
+        self.tokenizer = og.Tokenizer(self.model)
 
     def load_config(self, input_folder):
         config_path = os.path.join(input_folder, "genai_config.json")
@@ -124,7 +127,43 @@ def generate(
         streamer: OrtGenaiStreamer = None,
         pad_token_id=None,
         stopping_criteria=None,
+        chat_template="",
     ):
+
+        # Get model type
+        model_type = None
+        if hasattr(self.model, "type"):
+            model_type = self.model.type
+        else:
+            import json, os
+
+            with open(os.path.join(self.model_path, "genai_config.json"), "r") as f:
+                genai_config = json.load(f)
+                model_type = genai_config["model"]["type"]
+
+        # Set chat template
+        if chat_template:
+            if chat_template.count("{") != 1 or chat_template.count("}") != 1:
+                raise ValueError(
+                    "Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'"
+                )
+        else:
+            if model_type.startswith("phi2") or model_type.startswith("phi3"):
+                chat_template = "<|user|>\n{input} <|end|>\n<|assistant|>"
+            elif model_type.startswith("phi4"):
+                chat_template = "<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>"
+            elif model_type.startswith("llama3"):
+                chat_template = "<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+            elif model_type.startswith("llama2"):
+                chat_template = "<s>{input}"
+            elif model_type.startswith("qwen2"):
+                chat_template = (
+                    "<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n"
+                )
+            else:
+                raise ValueError(
+                    f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template"
+                )
+
         params = og.GeneratorParams(self.model)
 
         # There is a breaking API change in OGA 0.6.0
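The selected template is applied further down with chat_template.format(input=decoded_prompt), which is why a user-supplied --chat_template must contain exactly one pair of curly braces around the input word. A standalone illustration of that contract (the prompt text here is made up, not from the diff):

    chat_template = "<|user|>\n{input} <|end|>\n<|assistant|>"  # phi3-style default from above
    prompt = chat_template.format(input="Why is the sky blue?")
    # prompt == "<|user|>\nWhy is the sky blue? <|end|>\n<|assistant|>"

    # A template with zero, or more than one, brace pair fails the count("{") / count("}") check:
    # "<|user|>{input}<|end|>{extra}"  -> rejected with ValueError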
@@ -144,6 +183,13 @@ def generate(
         if use_oga_pre_6_api:
             params.input_ids = input_ids
 
+        if isinstance(input_ids, list):
+            input_ids_np = np.array(input_ids, dtype=np.int32)
+        else:
+            input_ids_np = input_ids.cpu().numpy().astype(np.int32)
+
+        decoded_prompt = self.tokenizer.decode(input_ids_np)
+
         if self.config and "search" in self.config:
             search_config = self.config["search"]
             params.set_search_options(
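The added block above normalizes input_ids before decoding: callers may pass either a plain Python list of token ids or a torch-style tensor, and the conversion yields a flat int32 numpy array in both cases before self.tokenizer.decode() is called. A rough standalone sketch of that normalization, assuming torch is installed (the token values are made up):

    import numpy as np
    import torch

    def to_int32_numpy(input_ids):
        # Lists convert directly; tensors are moved to CPU first, then converted.
        if isinstance(input_ids, list):
            return np.array(input_ids, dtype=np.int32)
        return input_ids.cpu().numpy().astype(np.int32)

    print(to_int32_numpy([101, 7592, 102]))                # [ 101 7592  102]
    print(to_int32_numpy(torch.tensor([101, 7592, 102])))  # [ 101 7592  102]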
@@ -177,8 +223,13 @@ def generate(
         params.try_graph_capture_with_max_batch_size(1)
 
         generator = og.Generator(self.model, params)
+        prompt = decoded_prompt
+        prompt = f"{chat_template.format(input=decoded_prompt)}"
+
+        input_tokens = self.tokenizer.encode(prompt)
+
         if use_oga_post_6_api:
-            generator.append_tokens(input_ids)
+            generator.append_tokens(input_tokens)
 
         if streamer is None:
             prompt_start_time = time.perf_counter()
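For reference, a hypothetical caller-side sketch of the new parameter; the folder path and token ids are placeholders, and only arguments visible in this diff are used:

    model = OrtGenaiModel("path/to/oga/model/folder")  # placeholder path
    input_ids = [101, 7592, 102]  # placeholder prompt token ids (normally produced by a tokenizer)

    # Default behavior: the template is inferred from the model type in genai_config.json.
    model.generate(input_ids)

    # Override: the custom template must contain exactly one {input} placeholder.
    model.generate(input_ids, chat_template="<|user|>\n{input} <|end|>\n<|assistant|>")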