 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch._dynamo.config
@@ -24,7 +24,9 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# Experimental features to reduce compilation times, will be on by default in future
+torch._inductor.config.fx_graph_cache = True
+torch._functorch.config.enable_autograd_cache = True

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -50,7 +52,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs

 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

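Note on the change above: with batching, `logits` has shape (batch_size, seq_len, vocab_size), so `logits[:, -1]` keeps the last-position logits for every sequence rather than only the first one. A minimal standalone sketch of the shapes involved (toy sizes, not taken from this file):

    import torch

    batch_size, seq_len, vocab_size = 4, 16, 32000
    logits = torch.randn(batch_size, seq_len, vocab_size)

    last = logits[:, -1]                                 # (batch_size, vocab_size): last position per sequence
    probs = torch.softmax(last / 0.8, dim=-1)            # temperature-scaled probabilities per sequence
    idx_next = torch.multinomial(probs, num_samples=1)   # (batch_size, 1): one sampled token per row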
@@ -76,7 +78,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+            cur_token = next_token.clone()

     return new_tokens, new_probs

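Note: next_token now keeps its batch dimension. Assuming the sampler returns one token per sequence (shape (batch_size, 1), as in the sketch above), reshaping with view(1, -1) would fold the batch into the sequence axis, while clone() preserves the shape for the next decode step. A tiny illustration with a made-up batch of 4:

    import torch

    next_token = torch.zeros(4, 1, dtype=torch.long)    # hypothetical sampled tokens, one per sequence
    assert next_token.clone().shape == (4, 1)           # batch dimension preserved
    assert next_token.view(1, -1).shape == (1, 4)       # old reshape would merge batch into sequence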
@@ -139,6 +141,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     draft_model: Transformer,
@@ -152,7 +155,7 @@ def generate(

     is_speculative = draft_model is not None
     # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
+    T = prompt.size(-1)
     T_new = T + max_new_tokens
     if interactive:
         max_seq_length = 350
@@ -162,20 +165,22 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
-    empty[:T] = prompt
+    empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
+    # We are just making the same prompt for every batch
+    prompt = prompt.view(1, -1).repeat(batch_size, 1)
+    empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)

-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
+        prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
+    seq[:, T] = next_token.squeeze()

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     accept_counts = [0] * (speculate_k + 1)
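Note: every row of the batch is primed with the same prompt; view(1, -1).repeat(batch_size, 1) tiles the 1-D prompt into a (batch_size, T) matrix before it is written into the output buffer. A small self-contained illustration with toy tensors (not this file's data):

    import torch

    batch_size = 3
    prompt = torch.tensor([5, 7, 11])                    # 1-D prompt, T = 3
    batched = prompt.view(1, -1).repeat(batch_size, 1)   # (3, 3): identical rows

    seq = torch.empty(batch_size, 8, dtype=prompt.dtype)
    seq[:, :prompt.size(-1)] = batched                   # fill the prompt region of every row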
@@ -197,8 +202,8 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
         else:
-            generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-            seq[T + 1:] = torch.cat(generated_tokens)
+            generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+            seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)

     generate_stats = {
         'accept_counts': accept_counts
@@ -245,6 +250,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):

 def _get_model_size(model):
     model_size = 0
+    params = 0
     for name, child in model.named_children():
         if not isinstance(child, torch.nn.Embedding):
             model_size += sum(
@@ -253,15 +259,22 @@ def _get_model_size(model):
                     for p in itertools.chain(child.parameters(), child.buffers())
                 ]
             )
-    return model_size
+            params += sum(
+                [
+                    p.numel()
+                    for p in itertools.chain(child.parameters(), child.buffers())
+                ]
+            )
+    return model_size, params

 B_INST, E_INST = "[INST]", "[/INST]"

 def main(
-    prompt: str = "Hello, my name is",
+    prompt: Union[int, str] = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
+    batch_size: int = 1,
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -307,11 +320,15 @@ def main(

     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

-    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-    prompt_length = encoded.size(0)
+    if isinstance(prompt, str):
+        encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    else:
+        # generate a fully synthetic prompt
+        encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
+    prompt_length = encoded.size(-1)

     torch.manual_seed(1234)
-    model_size = _get_model_size(model)
+    model_size, params = _get_model_size(model)
     if compile:
         if is_speculative and use_tp: # and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
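Note: when --prompt is given as an integer, tokenization is skipped and a random prompt of that length is used for benchmarking; the ids are drawn from [0, 1024), which is assumed to fall inside the model's vocabulary. A toy version of the same idea (illustrative length, runs on CPU):

    import torch

    prompt = 512                                                     # interpreted as a length, not as text
    encoded = torch.randint(0, 1024, (prompt,), dtype=torch.int64)   # synthetic token ids
    prompt_length = encoded.size(-1)                                 # 512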
@@ -371,6 +388,7 @@ def callback(x):
                 model,
                 encoded,
                 max_new_tokens,
+                batch_size=batch_size,
                 draft_model=draft_model,
                 speculate_k=speculate_k,
                 interactive=interactive,
@@ -391,21 +409,30 @@ def callback(x):
         t = time.perf_counter() - t0

         if not interactive:
-            print(tokenizer.decode(y.tolist()))
+            # Just displaying the first generation
+            if batch_size > 1:
+                print("Only displaying the first generation of the batch")
+            print(tokenizer.decode(y[0].tolist()))
         else:
             print()
-        tokens_generated = y.size(0) - prompt_length
-        tokens_sec = tokens_generated / t
-        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        tokens_generated = y.size(-1) - prompt_length
+        generated_tokens_sec = tokens_generated / t
+        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
+        print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
+        total_tokens_sec = y.numel() / t
+        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
+        print()
     print("==========")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
         acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
         print(f"Acceptance probs: {acceptance_probs}")
         print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")

+    print(f"Batch Size: {batch_size}")
+    print(f"Prompt Length: {prompt_length}")
+    print(f"Generated tokens: {max_new_tokens}")
     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

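Note on the reported metrics: model_size is in bytes, so the bandwidth line is weight bytes read per generated token times generated tokens per second, while the FLOPS line uses the common 2 * params FLOPs-per-token estimate (one multiply and one add per parameter) applied to all tokens in the returned sequences across the batch (y.numel() / t). A rough worked example with purely illustrative numbers, not measurements:

    params = 7e9               # parameter count
    model_size = 14e9          # weight bytes (roughly bf16 for 7B params)
    t = 2.0                    # seconds for one sample
    prompt_length, max_new_tokens, batch_size = 100, 200, 4

    generated_tokens_sec = max_new_tokens / t                              # 100 tokens/sec per sequence
    bandwidth_gb_s = model_size * generated_tokens_sec / 1e9               # 1400 GB/s
    total_tokens_sec = batch_size * (prompt_length + max_new_tokens) / t   # = y.numel() / t = 600
    flops_tf_s = params * total_tokens_sec * 2 / 1e12                      # 8.4 TF/s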
@@ -414,10 +441,17 @@ def callback(x):
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')

-    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+    def int_or_str(x):
+        try:
+            return int(x)
+        except:
+            return x
+
+    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
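Note: int_or_str lets the same --prompt flag carry either literal text or a synthetic-prompt length. A quick standalone check of the behavior (mirroring the helper above):

    def int_or_str(x):
        try:
            return int(x)
        except:            # anything non-integer falls through unchanged
            return x

    print(int_or_str("Hello, my name is"))   # 'Hello, my name is' -> used as a text prompt
    print(int_or_str("512"))                 # 512 -> used as a synthetic prompt length

Assuming the file under review is the repo's generate.py (the filename is not shown in this excerpt), a batched benchmark run would look like `python generate.py --prompt 512 --batch_size 4 --max_new_tokens 200`.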
@@ -430,7 +464,7 @@ def callback(x):

     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
         args.speculate_k, args.device
     )