@@ -55,9 +55,12 @@
# Using model name to identify the model to load, for example "llama2-7b-chat".
# You can change it to other values listed below.
# For details on the name-to-distribution mapping, see README.md or models.json.
+
+# Name: HF distribution name, dtype, and model dimension
NAME_TO_DISTRIBUTION_AND_DTYPE = {
-    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
-    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
+    "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16, 4096),
+    "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16, 4096),
+    "llama3-70b": ("meta-llama/Meta-Llama-3-70B-Instruct", torch.bfloat16, 8192),
}
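The third tuple element is each model's embedding (hidden) dimension, which downstream code now reads from the table instead of assuming 4096. A one-off sanity check like the sketch below (using Hugging Face's `AutoConfig`, not part of this change, and assuming `transformers` is installed and the mapping is importable) can confirm each hard-coded dimension matches the distribution's `hidden_size`:

```python
# Optional sanity check (assumption: not part of this diff): verify each
# recorded dimension against the Hugging Face config's hidden_size.
from transformers import AutoConfig

for name, (distribution, dtype, dim) in NAME_TO_DISTRIBUTION_AND_DTYPE.items():
    hf_config = AutoConfig.from_pretrained(distribution)
    assert hf_config.hidden_size == dim, (
        f"{name}: table says {dim}, HF config says {hf_config.hidden_size}"
    )
```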
@@ -314,8 +317,12 @@ def main(args):
    gpu_memory_monitor = GPUMemoryMonitor("cuda")
    logger.info(f"{color.yellow}{gpu_memory_monitor.get_device_info()}{color.reset}")

-    distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
-    logger.info(f"Using model weights from {distribution} and dtype {model_dtype}")
+    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[
+        model_name
+    ]
+    logger.info(
+        f"Using model weights from {distribution}, dtype {model_dtype} and model dimension {model_dimension}"
+    )

    # Model-level config
    model_config = ModelArgs.from_name(distribution)
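A small hardening idea, purely hypothetical and not part of this diff: surfacing the supported names on a bad lookup reads better than the bare `KeyError` the subscript raises today.

```python
# Hypothetical guard: report the supported model names on an unknown lookup.
try:
    distribution, model_dtype, model_dimension = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name]
except KeyError:
    raise ValueError(
        f"Unknown model {model_name!r}; expected one of "
        f"{sorted(NAME_TO_DISTRIBUTION_AND_DTYPE)}"
    ) from None
```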
@@ -338,6 +345,7 @@ def main(args):

    # Tensor parallel is enabled in this program
    tp_degree = world_size // pp_degree
+    logger.info(f"Using TP degree {tp_degree} and PP degree {pp_degree}")

    # Create device mesh
    mesh_dimensions = (pp_degree, tp_degree)
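For context, `mesh_dimensions` feeds a 2-D device mesh with pipeline parallel as the outer dimension and tensor parallel as the inner one. A minimal standalone sketch of that layout, using PyTorch's `init_device_mesh` with hypothetical degrees and dimension names (not taken from this file):

```python
from torch.distributed.device_mesh import init_device_mesh

# Minimal sketch, assuming torchrun launched world_size = pp_degree * tp_degree
# ranks; the names ("pp", "tp") are illustrative, not from this diff.
pp_degree, tp_degree = 2, 4  # hypothetical: 8 GPUs split 2-way PP x 4-way TP
mesh = init_device_mesh("cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp"))
tp_group = mesh["tp"].get_group()  # process group for tensor-parallel collectives
```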
@@ -388,7 +396,6 @@ def main(args):
    # sense. Thus it is interchangeable with micro-batch size below.
    batch_size = len(prompt)
    seqlen_prefill = 1024  # sequence length
-    dim = 4096  # embedding dimension

    # Setup KV caches (after model distribution)
    # The number of cache lanes is the same as the maximum number of
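Dropping the hard-coded `dim = 4096` matters because the per-model dimension now flows from the table; for llama3-70b the example activation doubles in width. A back-of-envelope sketch, assuming bf16 at two bytes per element:

```python
# Rough size of the example activation buffer, assuming bf16 (2 bytes/element).
batch_size, seqlen_prefill = 1, 1024  # hypothetical single-prompt prefill
bytes_per_elem = 2
for name, dim in [("llama3", 4096), ("llama3-70b", 8192)]:
    mib = batch_size * seqlen_prefill * dim * bytes_per_elem / 2**20
    print(f"{name}: {mib:.0f} MiB")  # llama3: 8 MiB, llama3-70b: 16 MiB
```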
@@ -419,7 +426,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
            0, config.vocab_size, (batch_size, seqlen), device=device
        )
        activation = torch.rand(
-            batch_size, seqlen, dim, device=device, dtype=model_dtype
+            batch_size, seqlen, model_dimension, device=device, dtype=model_dtype
        )
        logits = torch.rand(
            batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype
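These example tensors exist only to give each pipeline stage concrete input/output shapes: the first stage sees token ids, later stages see activations of width `model_dimension`. A self-contained sketch of the same idea (hypothetical shapes and vocab size, not this file's stage wiring):

```python
import torch

# Hypothetical, standalone version of the shape-example idea.
def example_ins_outs(batch_size, seqlen, vocab_size, model_dimension, dtype):
    tokens = torch.randint(0, vocab_size, (batch_size, seqlen))
    activation = torch.rand(batch_size, seqlen, model_dimension, dtype=dtype)
    logits = torch.rand(batch_size, seqlen, vocab_size, dtype=dtype)
    return tokens, activation, logits

tokens, act, logits = example_ins_outs(1, 1024, 128256, 8192, torch.bfloat16)
assert act.shape == (1, 1024, 8192) and logits.shape[-1] == 128256
```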