@@ -125,7 +125,6 @@ def generate(
125
125
126
126
max_length = len (input_ids ) + max_new_tokens
127
127
128
- params .input_ids = input_ids
129
128
if self .config and "search" in self .config :
130
129
search_config = self .config ["search" ]
131
130
params .set_search_options (
@@ -159,10 +158,10 @@ def generate(
159
158
params .try_graph_capture_with_max_batch_size (1 )
160
159
161
160
generator = og .Generator (self .model , params )
161
+ generator .append_tokens (input_ids )
162
162
163
163
if streamer is None :
164
164
prompt_start_time = time .perf_counter ()
165
- generator .compute_logits ()
166
165
generator .generate_next_token ()
167
166
prompt_end_time = time .perf_counter ()
168
167
@@ -173,7 +172,6 @@ def generate(
173
172
token_gen_times = []
174
173
while not generator .is_done ():
175
174
token_gen_start_time = time .perf_counter ()
176
- generator .compute_logits ()
177
175
generator .generate_next_token ()
178
176
token_gen_end_time = time .perf_counter ()
179
177
@@ -194,7 +192,6 @@ def generate(
194
192
stop_early = False
195
193
196
194
while not generator .is_done () and not stop_early :
197
- generator .compute_logits ()
198
195
generator .generate_next_token ()
199
196
200
197
new_token = generator .get_next_tokens ()[0 ]
@@ -253,6 +250,13 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
253
250
add_help = add_help ,
254
251
)
255
252
253
+ parser .add_argument (
254
+ "-ip" ,
255
+ "--input_path" ,
256
+ default = "" ,
257
+ help = "the local huggingface model in your disk" ,
258
+ )
259
+
256
260
parser .add_argument (
257
261
"-d" ,
258
262
"--device" ,
@@ -304,6 +308,7 @@ def run(
304
308
self ,
305
309
state : State ,
306
310
input : str ,
311
+ input_path : str = "" ,
307
312
device : str = "igpu" ,
308
313
dtype : str = "int4" ,
309
314
int4_block_size : int = None ,
@@ -449,7 +454,7 @@ def run(
449
454
try :
450
455
model_builder .create_model (
451
456
checkpoint , # model_name
452
- "" , # input_path
457
+ input_path , # input_path
453
458
full_model_path , # output_path
454
459
dtype , # precision
455
460
execution_providers [device ], # execution_provider
0 commit comments