Commit 61c193d

Add support for batched generation and synthetic long prompt
1 parent bc04265 commit 61c193d

1 file changed: +61 -27 lines

generate.py (+61 -27)
@@ -7,7 +7,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch._dynamo.config
@@ -24,7 +24,9 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# Experimental features to reduce compilation times, will be on by default in future
+torch._inductor.config.fx_graph_cache = True
+torch._functorch.config.enable_autograd_cache = True

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

@@ -50,7 +52,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs

 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

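Context for the change above: indexing with logits[:, -1] keeps the batch dimension, so sampling returns one next token per sequence instead of only ever sampling for the first row. A minimal standalone sketch of that behavior (toy shapes and plain torch ops, not the repo's helpers):

import torch

# Toy shapes: batch of 4 sequences, 16 positions, 1000-token vocab
logits = torch.randn(4, 16, 1000)

last = logits[:, -1]                                  # [4, 1000]: one row per sequence
probs = torch.softmax(last / 0.8, dim=-1)             # temperature-scaled probabilities
next_tokens = torch.multinomial(probs, num_samples=1)
print(next_tokens.shape)                              # torch.Size([4, 1])

# The previous logits[0, -1] collapses to a single [1000] row, so only the
# first sequence in a batch would ever drive sampling.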
@@ -76,7 +78,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1])
             new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+            cur_token = next_token.clone()

     return new_tokens, new_probs

@@ -139,6 +141,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     draft_model: Transformer,
@@ -152,7 +155,7 @@ def generate(

     is_speculative = draft_model is not None
     # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
+    T = prompt.size(-1)
     T_new = T + max_new_tokens
     if interactive:
         max_seq_length = 350
@@ -162,20 +165,22 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
-    empty[:T] = prompt
+    empty = torch.empty(batch_size, T_new, dtype=dtype, device=device)
+    # We are just making the same prompt for every batch
+    prompt = prompt.view(1, -1).repeat(batch_size, 1)
+    empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)

-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
+        prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
+    seq[:, T] = next_token.squeeze()

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     accept_counts = [0] * (speculate_k + 1)
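The batched path above replicates a single prompt across every row and pre-allocates the output buffer with a leading batch dimension. A rough standalone sketch of just the tensor bookkeeping, with toy sizes and no model or KV cache involved:

import torch

batch_size, T, max_new_tokens = 4, 6, 3
prompt = torch.arange(T)                             # stand-in token ids, shape [T]

T_new = T + max_new_tokens
seq = torch.empty(batch_size, T_new, dtype=prompt.dtype)
prompt = prompt.view(1, -1).repeat(batch_size, 1)    # same prompt in every row
seq[:, :T] = prompt

next_token = torch.randint(0, 100, (batch_size, 1))  # stand-in for the prefill output
seq[:, T] = next_token.squeeze()
print(seq.shape)                                     # torch.Size([4, 9])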
@@ -197,8 +202,8 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+        seq[:, T + 1:] = torch.cat(generated_tokens, dim=-1)

     generate_stats = {
         'accept_counts': accept_counts
@@ -245,6 +250,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):

 def _get_model_size(model):
     model_size = 0
+    params = 0
     for name, child in model.named_children():
         if not isinstance(child, torch.nn.Embedding):
             model_size += sum(
@@ -253,15 +259,22 @@ def _get_model_size(model):
                     for p in itertools.chain(child.parameters(), child.buffers())
                 ]
             )
-    return model_size
+            params += sum(
+                [
+                    p.numel()
+                    for p in itertools.chain(child.parameters(), child.buffers())
+                ]
+            )
+    return model_size, params

 B_INST, E_INST = "[INST]", "[/INST]"

 def main(
-    prompt: str = "Hello, my name is",
+    prompt: Union[int, str] = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
+    batch_size: int = 1,
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
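_get_model_size now returns both the weight footprint in bytes (used for the bandwidth estimate) and the raw parameter count (used for the new FLOPS estimate). A self-contained sketch of the same accounting on a toy module; like the loop above, it skips nn.Embedding children:

import itertools
import torch

model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 64),   # excluded from both totals
    torch.nn.Linear(64, 64),
)

model_size, params = 0, 0
for name, child in model.named_children():
    if not isinstance(child, torch.nn.Embedding):
        tensors = list(itertools.chain(child.parameters(), child.buffers()))
        model_size += sum(p.numel() * p.dtype.itemsize for p in tensors)
        params += sum(p.numel() for p in tensors)

print(params)      # 64*64 + 64 = 4160 parameters
print(model_size)  # 4160 * 4 bytes (fp32) = 16640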
@@ -307,11 +320,15 @@ def main(

     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

-    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-    prompt_length = encoded.size(0)
+    if isinstance(prompt, str):
+        encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    else:
+        # generate a fully synthetic prompt
+        encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
+    prompt_length = encoded.size(-1)

     torch.manual_seed(1234)
-    model_size = _get_model_size(model)
+    model_size, params = _get_model_size(model)
     if compile:
         if is_speculative and use_tp: # and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
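With an integer --prompt, the tokenizer is bypassed entirely: the prompt is just that many random token ids, which is enough for benchmarking long-context prefill. A quick sketch of what that branch produces (the 0-1024 id range is the same arbitrary range used above):

import torch

prompt = 4096   # e.g. --prompt 4096 on the command line
encoded = torch.randint(0, 1024, (prompt,), dtype=torch.int64)
print(encoded.shape)     # torch.Size([4096]): a synthetic 4096-token prompt
print(encoded.size(-1))  # prompt_length == 4096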
@@ -371,6 +388,7 @@ def callback(x):
                 model,
                 encoded,
                 max_new_tokens,
+                batch_size=batch_size,
                 draft_model=draft_model,
                 speculate_k=speculate_k,
                 interactive=interactive,
@@ -391,21 +409,30 @@ def callback(x):
         t = time.perf_counter() - t0

         if not interactive:
-            print(tokenizer.decode(y.tolist()))
+            # Just displaying the first generation
+            if batch_size > 1:
+                print("Only displaying the first generation of the batch")
+            print(tokenizer.decode(y[0].tolist()))
         else:
             print()
-        tokens_generated = y.size(0) - prompt_length
-        tokens_sec = tokens_generated / t
-        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        tokens_generated = y.size(-1) - prompt_length
+        generated_tokens_sec = tokens_generated / t
+        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
+        print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
+        total_tokens_sec = y.numel() / t
+        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
+        print()
     print("==========")
     if is_speculative:
         counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
         acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
         print(f"Acceptance probs: {acceptance_probs}")
         print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")

+    print(f"Batch Size: {batch_size}")
+    print(f"Prompt Length: {prompt_length}")
+    print(f"Generated tokens: {max_new_tokens}")
     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

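The reporting above distinguishes per-sequence decode throughput (generated_tokens_sec, which feeds the bandwidth number) from whole-batch throughput (total_tokens_sec, which feeds the FLOPS number, assuming roughly 2 FLOPs per parameter per token). A back-of-the-envelope check with made-up numbers, approximating y.numel() by batch_size * (prompt + new tokens):

# Hypothetical run: 7B-parameter model in bf16, batch of 4,
# 128-token prompt, 100 new tokens, 2.0 seconds of decode time.
params = 7e9
model_size = params * 2                      # ~14e9 bytes of bf16 weights
batch_size, prompt_length, max_new_tokens, t = 4, 128, 100, 2.0

tokens_generated = max_new_tokens            # per sequence
generated_tokens_sec = tokens_generated / t  # 50 tok/s per sequence
total_tokens_sec = batch_size * (prompt_length + max_new_tokens) / t  # ~ y.numel() / t

print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")  # ~700 GB/s
print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")         # ~6.4 TF/s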
@@ -414,10 +441,17 @@ def callback(x):
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')

-    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+    def int_or_str(x):
+        try:
+            return int(x)
+        except:
+            return x
+
+    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
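The int_or_str helper is what lets one flag serve both modes: a value that parses as an integer selects the synthetic-prompt path, anything else is treated as ordinary prompt text. A small standalone demonstration of that routing (it mirrors the behavior of the parser above, using a narrower except clause):

def int_or_str(x):
    try:
        return int(x)
    except ValueError:
        return x

print(int_or_str("8192"))               # 8192 (int) -> synthetic 8192-token prompt
print(int_or_str("Hello, my name is"))  # unchanged string -> tokenized as usual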
@@ -430,7 +464,7 @@ def callback(x):

     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
         args.speculate_k, args.device
     )
