
Commit e53ffb5

Merge branch 'main' into bf/flex-decoding-integrate
2 parents: 2b7976c + 61c193d

File tree

2 files changed: +63 / -29 lines


eval.py

Lines changed: 2 additions & 2 deletions
@@ -165,7 +165,7 @@ def eval(
     Args:
         model (Transformer): The pre-trained language model to evaluate.
         tokenizer: The tokenizer to use for encoding/decoding text.
-        task (str): The name of the evaluation task to perform.
+        tasks (list): The names of the evaluation tasks to perform.
         limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
         max_seq_length (Optional[int]): The maximum sequence length allowed for input text.

@@ -208,7 +208,7 @@ def main(
     Args:
         checkpoint_path (Path): The path to the model checkpoint file to load.
        compile (bool): Whether or not to compile the model for optimization.
-        task (Optional[str]): The name of the evaluation task or a list of tasks to perform.
+        tasks (list): The names of the evaluation tasks to perform.
        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
        max_seq_length (Optional[int]): The maximum sequence length allowed for input text.


generate.py

Lines changed: 61 additions & 27 deletions
@@ -7,7 +7,7 @@
 import sys
 import time
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch._dynamo.config
@@ -25,7 +25,9 @@ def device_sync(device):

 torch._inductor.config.coordinate_descent_tuning = True
 torch._inductor.config.triton.unique_kernel_names = True
-torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+# Experimental features to reduce compilation times, will be on by default in future
+torch._inductor.config.fx_graph_cache = True
+torch._functorch.config.enable_autograd_cache = True

 default_device = 'cuda' if torch.cuda.is_available() else 'cpu'


@@ -51,7 +53,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs

 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

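A note on the sample() change above: indexing with logits[:, -1] keeps the batch dimension, so one token is drawn per sequence rather than a single draw for batch index 0 only. A minimal standalone sketch of the shape behaviour (temperature shown, top-k filtering omitted; these are not the repo's helper functions):

import torch

batch_size, seq_len, vocab = 4, 16, 32000
logits = torch.randn(batch_size, seq_len, vocab)

# old: logits[0, -1]  -> shape (vocab,), one distribution shared by the whole batch
# new: logits[:, -1]  -> shape (batch_size, vocab), one distribution per sequence
probs = torch.softmax(logits[:, -1] / 0.8, dim=-1)   # temperature = 0.8
idx_next = torch.multinomial(probs, num_samples=1)   # shape (batch_size, 1)
print(idx_next.shape)                                # torch.Size([4, 1])
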

@@ -86,7 +88,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
            new_tokens.append(next_token.clone())
            callback(new_tokens[-1])
            new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+            cur_token = next_token.clone()

     return new_tokens, new_probs

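The decode_n_tokens change follows from the same batching: sample() now returns one token per batch row, so next_token already has shape (batch_size, 1) and can be carried forward with .clone(); the old .view(1, -1) would fold a batch into a single row. A tiny illustration:

import torch

batch_size = 4
next_token = torch.zeros(batch_size, 1, dtype=torch.int64)  # one new token per sequence
print(next_token.view(1, -1).shape)  # torch.Size([1, 4]) -- batch collapsed into one row
print(next_token.clone().shape)      # torch.Size([4, 1]) -- shape preserved for the next step
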

@@ -149,6 +151,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     draft_model: Transformer,
@@ -162,7 +165,7 @@ def generate(

     is_speculative = draft_model is not None
     # create an empty tensor of the expected final shape and fill in the current tokens
-    T = prompt.size(0)
+    T = prompt.size(-1)
     T_new = T + max_new_tokens
     T_buf = roundup(T_new, 128) # round up to multiple of 128 to use flex_attention
     if interactive:
@@ -173,20 +176,22 @@ def generate(
     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     with torch.device(device):
-        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)
         if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+            draft_model.setup_caches(max_batch_size=batch_size, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_buf, dtype=dtype, device=device)
-    empty[:T] = prompt
+    empty = torch.empty(batch_size, T_buf, dtype=dtype, device=device)
+    # We are just making the same prompt for every batch
+    prompt = prompt.view(1, -1).repeat(batch_size, 1)
+    empty[:, :T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)

-    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    seq[T] = next_token
+        prefill(draft_model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs)
+    seq[:, T] = next_token.squeeze()

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     accept_counts = [0] * (speculate_k + 1)
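
For context on the hunk above: the KV caches are now sized for batch_size sequences, and the single prompt is duplicated across the batch so every row decodes from the same prefix. A rough sketch of the buffer setup under the (batch_size, T_buf) layout used in the diff (toy sizes, not the repo's Transformer or prefill):

import torch

batch_size, max_new_tokens = 4, 8
prompt = torch.tensor([1, 15, 27, 42])               # one encoded prompt, shape (T,)
T = prompt.size(-1)
T_buf = ((T + max_new_tokens + 127) // 128) * 128    # round up to a multiple of 128 for flex_attention

seq = torch.empty(batch_size, T_buf, dtype=prompt.dtype)
prompt = prompt.view(1, -1).repeat(batch_size, 1)    # same prompt for every batch row
seq[:, :T] = prompt
print(seq.shape)                                     # torch.Size([4, 128])
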
@@ -208,8 +213,8 @@ def generate(
            input_pos = input_pos + num_added
            next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:T_new] = torch.cat(generated_tokens)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+        seq[:, T + 1:T_new] = torch.cat(generated_tokens, dim=-1)

     generate_stats = {
         'accept_counts': accept_counts
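
decode_n_tokens returns a list of (batch_size, 1) tensors, one per decoding step, so they are concatenated along the last dimension before being written into the batched seq buffer. Shape-wise:

import torch

batch_size, steps = 4, 3
generated_tokens = [torch.full((batch_size, 1), i) for i in range(steps)]  # one (B, 1) tensor per step
print(torch.cat(generated_tokens, dim=-1).shape)  # torch.Size([4, 3])
# torch.cat(generated_tokens) with the default dim=0 would give (12, 1) instead.
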
@@ -256,6 +261,7 @@ def _load_model(checkpoint_path, device, precision, use_tp):

 def _get_model_size(model):
     model_size = 0
+    params = 0
     for name, child in model.named_children():
         if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
@@ -264,15 +270,22 @@ def _get_model_size(model):
                    for p in itertools.chain(child.parameters(), child.buffers())
                ]
            )
-    return model_size
+            params += sum(
+                [
+                    p.numel()
+                    for p in itertools.chain(child.parameters(), child.buffers())
+                ]
+            )
+    return model_size, params

 B_INST, E_INST = "[INST]", "[/INST]"

 def main(
-    prompt: str = "Hello, my name is",
+    prompt: Union[int, str] = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
+    batch_size: int = 1,
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
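
_get_model_size now returns a parameter count alongside the byte count so the benchmark can report FLOPS as well as bandwidth. A simplified standalone sketch of the same idea (hypothetical name model_size_and_params; like the real function it skips torch.nn.Embedding children and counts buffers too):

import itertools
import torch

def model_size_and_params(model: torch.nn.Module):
    """Byte size and element count of all non-embedding parameters/buffers (sketch)."""
    size_bytes, params = 0, 0
    for _, child in model.named_children():
        if isinstance(child, torch.nn.Embedding):
            continue
        for p in itertools.chain(child.parameters(), child.buffers()):
            size_bytes += p.numel() * p.dtype.itemsize
            params += p.numel()
    return size_bytes, params

m = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))
print(model_size_and_params(m))  # (67108864, 16777216) for fp32 weights
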
@@ -318,11 +331,15 @@ def main(

     tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

-    encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
-    prompt_length = encoded.size(0)
+    if isinstance(prompt, str):
+        encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
+    else:
+        # generate a fully synthetic prompt
+        encoded = torch.randint(0, 1024, (prompt,), device=device, dtype=torch.int64)
+    prompt_length = encoded.size(-1)

     torch.manual_seed(1234)
-    model_size = _get_model_size(model)
+    model_size, params = _get_model_size(model)
     if compile:
         if is_speculative and use_tp: # and ("cuda" in device):
             torch._inductor.config.triton.cudagraph_trees = False # Bug with cudagraph trees in this case
@@ -382,6 +399,7 @@ def callback(x):
                model,
                encoded,
                max_new_tokens,
+                batch_size=batch_size,
                draft_model=draft_model,
                speculate_k=speculate_k,
                interactive=interactive,
@@ -402,21 +420,30 @@ def callback(x):
        t = time.perf_counter() - t0

        if not interactive:
-            print(tokenizer.decode(y.tolist()))
+            # Just displaying the first generation
+            if batch_size > 1:
+                print("Only displaying the first generation of the batch")
+            print(tokenizer.decode(y[0].tolist()))
        else:
            print()
-        tokens_generated = y.size(0) - prompt_length
-        tokens_sec = tokens_generated / t
-        aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
-        print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
+        tokens_generated = y.size(-1) - prompt_length
+        generated_tokens_sec = tokens_generated / t
+        aggregate_metrics['tokens_per_sec'].append(generated_tokens_sec)
+        print(f"Time for inference {i + 1}: {t:.02f} sec total, {generated_tokens_sec:.02f} tokens/sec")
+        print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")
+        total_tokens_sec = y.numel() / t
+        print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")
+        print()
     print("==========")
     if is_speculative:
        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
        acceptance_probs = [i/sum(counts_aggregated) for i in counts_aggregated]
        print(f"Acceptance probs: {acceptance_probs}")
        print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)])/sum(counts_aggregated)}")

+    print(f"Batch Size: {batch_size}")
+    print(f"Prompt Length: {prompt_length}")
+    print(f"Generated tokens: {max_new_tokens}")
     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")

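The new metrics separate per-sequence decoding speed from whole-batch throughput: bandwidth still scales generated tokens per second by the model's weight bytes, while the FLOPS estimate counts roughly 2 FLOPs per parameter per token over every token produced by the batch (y.numel()). Worked with assumed round numbers (a 7e9-parameter model with 2-byte weights, batch size 4), purely illustrative:

params = 7e9                        # parameter count from _get_model_size
model_size = params * 2             # bytes, assuming 2-byte (fp16/bf16) weights
generated_tokens_sec = 50           # new tokens per second for one sequence
total_tokens_sec = 4 * 600 / 2.0    # y.numel() / t: 4 sequences of 600 tokens in 2 s

print(f"Bandwidth achieved: {model_size * generated_tokens_sec / 1e9:.02f} GB/s")  # 700.00 GB/s
print(f"FLOPS achieved: {params * total_tokens_sec * 2 / 1e12:.02f} TF/s")         # 16.80 TF/s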

@@ -425,10 +452,17 @@ def callback(x):
     import argparse
     parser = argparse.ArgumentParser(description='Your CLI description.')

-    parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
+    def int_or_str(x):
+        try:
+            return int(x)
+        except:
+            return x
+
+    parser.add_argument('--prompt', type=int_or_str, default="Hello, my name is", help="Input prompt. If it's an integer, will instead generate a synthetic prompt.")
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
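
To see how the new int_or_str converter routes --prompt, here is a self-contained sketch of the same argparse setup, trimmed to the two flags added or changed in this hunk:

import argparse

def int_or_str(x):
    try:
        return int(x)
    except ValueError:
        return x

p = argparse.ArgumentParser()
p.add_argument('--prompt', type=int_or_str, default="Hello, my name is")
p.add_argument('--batch_size', type=int, default=1)

args = p.parse_args(['--prompt', '512', '--batch_size', '4'])
print(type(args.prompt).__name__, args.prompt, args.batch_size)  # int 512 4 -> synthetic 512-token prompt
args = p.parse_args(['--prompt', 'Hello, my name is'])
print(type(args.prompt).__name__)                                # str -> tokenized as usual
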
@@ -441,7 +475,7 @@ def callback(x):

     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
         args.speculate_k, args.device
     )
