@@ -7,7 +7,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch._dynamo.config
@@ -25,7 +25,9 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# Experimental features to reduce compilation times, will be on by default in future
+torch._inductor.config.fx_graph_cache = True
+torch._functorch.config.enable_autograd_cache = True

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -51,7 +53,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs

 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

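This is the change that makes sampling batch-aware: `logits[:, -1]` keeps the batch dimension and takes the final-position logits of every row, where `logits[0, -1]` only ever read the first sequence. A minimal shape sketch (sizes are illustrative, not from the patch):

    import torch
    logits = torch.randn(4, 16, 32000)        # [batch, seq, vocab]
    last = logits[:, -1]                       # [4, 32000]: one distribution per row
    probs = torch.softmax(last, dim=-1)
    idx_next = torch.multinomial(probs, 1)     # [4, 1]: one next token per row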
@@ -86,7 +88,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
         new_tokens.append(next_token.clone())
         callback(new_tokens[-1])
         new_probs.append(next_prob.clone())
-        cur_token = next_token.view(1, -1)
+        cur_token = next_token.clone()

     return new_tokens, new_probs

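Dropping the `view(1, -1)` relies on `sample` already returning a `[batch_size, 1]` tensor, which holds if `multinomial_sample_one_no_sync` uses the exponential-race argmax with `keepdim=True`; a sketch under that assumption:

    probs = torch.rand(4, 32000)                          # [batch, vocab]
    q = torch.empty_like(probs).exponential_(1)
    idx = torch.argmax(probs / q, dim=-1, keepdim=True)   # already [4, 1]

The `.clone()` then only decouples `cur_token` from the buffer returned by the decode step before the next iteration.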
@@ -149,6 +151,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     draft_model: Transformer,
@@ -162,7 +165,7 @@ def generate(

     is_speculative = draft_model is not None
     # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
+    T = prompt.size(-1)
     T_new = T + max_new_tokens
     T_buf = roundup(T_new, 128)  # round up to multiple of 128 to use flex_attention
     if interactive:
@@ -173,20 +176,22 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_buf, dtype=dtype, device=device)
-    empty[:T] = prompt
+    empty = torch.empty(batch_size, T_buf, dtype=dtype, device=device)
+    # We are just making the same prompt for every batch
+    prompt = prompt.view(1, -1).repeat(batch_size, 1)
+    empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)

-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
+        prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
+    seq[:, T] = next_token.squeeze()

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     accept_counts = [0] * (speculate_k + 1)
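Every row of the batch runs the same prompt, so the setup simply tiles it: `view(1, -1)` adds a batch axis and `repeat(batch_size, 1)` copies it down the rows before the batched prefill. Illustrative shapes:

    prompt = torch.tensor([5, 6, 7])           # [T], T = 3
    tiled = prompt.view(1, -1).repeat(4, 1)    # [4, 3]: identical rows
    empty = torch.empty(4, 8, dtype=prompt.dtype)
    empty[:, :3] = tiled                       # fill the first T columns of each row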
@@ -208,8 +213,8 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:T_new] = torch.cat(generated_tokens)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+        seq[:, T + 1:T_new] = torch.cat(generated_tokens, dim=-1)

     generate_stats = {
         'accept_counts': accept_counts
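Each entry of `generated_tokens` is a `[batch_size, 1]` column, so concatenating along the last dim rebuilds a `[batch_size, max_new_tokens - 1]` block that drops into the batched `seq` slice. A quick check with illustrative sizes:

    cols = [torch.full((4, 1), i) for i in range(5)]
    block = torch.cat(cols, dim=-1)            # [4, 5]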
@@ -256,6 +261,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):

 def _get_model_size(model):
     model_size = 0
+    params = 0
     for name, child in model.named_children():
         if not isinstance(child, torch.nn.Embedding):
             model_size += sum(
@@ -264,15 +270,22 @@ def _get_model_size(model):
                     for p in itertools.chain(child.parameters(), child.buffers())
                 ]
             )
-    return model_size
+            params += sum(
+                [
+                    p.numel()
+                    for p in itertools.chain(child.parameters(), child.buffers())
+                ]
+            )
+    return model_size, params

 B_INST, E_INST = "[INST]", "[/INST]"

 def main(
-    prompt: str = "Hello, my name is",
+    prompt: Union[int, str] = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
+    batch_size: int = 1,
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
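`_get_model_size` now returns both the byte footprint of the non-embedding weights (feeding the bandwidth metric) and their raw element count (feeding the new FLOPS metric). A sketch of the distinction for one child module, assuming the elided summand is `p.numel() * p.dtype.itemsize`:

    linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16)
    params = sum(p.numel() for p in linear.parameters())                         # 16,777,216 elements
    nbytes = sum(p.numel() * p.dtype.itemsize for p in linear.parameters())      # 2 bytes each for bf16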
@@ -318,11 +331,15 @@ def main(

     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

-    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-    prompt_length = encoded.size(0)
+    if isinstance(prompt, str):
+        encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    else:
+        # generate a fully synthetic prompt
+        encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
+    prompt_length = encoded.size(-1)

     torch.manual_seed(1234)
-    model_size = _get_model_size(model)
+    model_size, params = _get_model_size(model)
     if compile:
         if is_speculative and use_tp:  # and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False  # Bug with cudagraph trees in this case
@@ -382,6 +399,7 @@ def callback(x):
             model,
             encoded,
             max_new_tokens,
+            batch_size=batch_size,
             draft_model=draft_model,
             speculate_k=speculate_k,
             interactive=interactive,
@@ -402,21 +420,30 @@ def callback(x):
         t = time.perf_counter() - t0

         if not interactive:
-            print(tokenizer.decode(y.tolist()))
+            # Just displaying the first generation
+            if batch_size > 1:
+                print("Only displaying the first generation of the batch")
+            print(tokenizer.decode(y[0].tolist()))
         else:
             print()
-        tokens_generated = y.size(0) - prompt_length
-        tokens_sec = tokens_generated / t
-        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        tokens_generated = y.size(-1) - prompt_length
+        generated_tokens_sec = tokens_generated / t
+        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
+        print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
+        total_tokens_sec = y.numel() / t
+        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
+        print()
     print("==========")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
         print(f"Acceptance probs: {acceptance_probs}")
         print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)]) / sum(counts_aggregated)}")

+    print(f"Batch Size: {batch_size}")
+    print(f"Prompt Length: {prompt_length}")
+    print(f"Generated tokens: {max_new_tokens}")
     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

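The new FLOPS line applies the usual estimate of roughly 2 FLOPs per parameter per token for a decoder forward pass, counted over every token the batch produces (`y.numel() / t`); the bandwidth line still counts only net new tokens per sequence. As a worked example, 7e9 parameters at 1000 total tokens/sec gives 2 * 7e9 * 1000 / 1e12 = 14 TF/s.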
@@ -425,10 +452,17 @@ def callback(x):
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')

-    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+    def int_or_str(x):
+        try:
+            return int(x)
+        except:
+            return x
+
+    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
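With `int_or_str`, `--prompt` does double duty: a string is tokenized as real text, while an integer requests a synthetic random prompt of that length, which pairs naturally with `--batch_size` for benchmarking. Hypothetical invocations (script name assumed):

    python generate.py --prompt "Hello, my name is" --batch_size 1
    python generate.py --prompt 2048 --batch_size 8 --max_new_tokens 200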
@@ -441,7 +475,7 @@ def callback(x):

     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
         args.speculate_k, args.device
     )