
Commit 0f70507

dynamically update torch.compile cache config to ensure async tp support, enhance async tp UX (#471)
This PR adds the following enhancements for Async TP support:

1 - If Async TP is active, the torch._dynamo cache limit is automatically raised to 10K. Without this, Async TP is not activated on larger models: compilation quietly stops once the 'cache limit reached' threshold is hit, with no information for the user. The config update is logged.

2 - If Async TP is enabled, the job config is checked to verify that torch.compile is set to true. If it is not, a warning is logged and torch.compile is enabled so the user gets working Async TP (see the WARNING in the screenshot below).

3 - The 'Applied Tensor Parallel' log line for the model becomes 'Applied Async Tensor Parallel' when Async TP is active, to make clear in the logs which TP mode is in use (see the screenshot).

Screenshot: https://github.com/user-attachments/assets/26e5a48e-4bb8-4f33-b1b5-8939c1517c1d
1 parent d76b77f
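For context on item 1, here is a minimal, illustrative sketch (the function and shapes below are invented for illustration, not part of this PR) of the behavior torch._dynamo.config.cache_size_limit controls: once a function has been recompiled more times than the limit allows, dynamo stops compiling that code object and quietly falls back to eager, which is what kept Async TP from engaging on larger models at the default limit.

import torch

# Give dynamo a much larger recompile budget, as the PR does when Async TP
# is active. The default limit is small and easily exhausted on big models.
torch._dynamo.config.cache_size_limit = 10000

@torch.compile
def double(x):
    return x * 2

# Varying input shapes can force recompilations; with a small
# cache_size_limit the function would eventually hit "cache limit reached"
# and silently run in eager mode instead of compiled code.
for n in range(1, 9):
    double(torch.randn(n))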

File tree

2 files changed: +16 −1 lines changed

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 15 additions & 1 deletion
@@ -413,13 +413,27 @@ def apply_tp(
             parallelize_plan=layer_plan,
         )
 
+    # updates expressly for async tensor parallel
     if job_config.experimental.enable_async_tensor_parallel:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
+        torch._dynamo.config.cache_size_limit = 10000
+        logger.info(
+            "Updating torch._dynamo.config.cache_size_limit to 10000 to support Async TP"
+        )
+
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
-    logger.info("Applied Tensor Parallelism to the model")
+        if not job_config.training.compile:
+            logger.warning(
+                "Async TP requires compilation...auto enabling compile = True for this job to resolve."
+            )
+            job_config.training.compile = True
+
+    logger.info(
+        f"Applied{' Async ' if job_config.experimental.enable_async_tensor_parallel else ' '}Tensor Parallelism to the model"
+    )
     return model
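For readers applying the same settings outside torchtitan, a hypothetical helper (the name enable_async_tp and the model/tp_mesh arguments are placeholders, not part of this PR) that mirrors what the hunk above does when Async TP is enabled:

import torch
import torch._dynamo
import torch._inductor.config
from torch import nn
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
from torch.distributed.device_mesh import DeviceMesh


def enable_async_tp(model: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
    # Raise the recompile budget so large models do not silently fall back
    # to eager before Async TP takes effect.
    torch._dynamo.config.cache_size_limit = 10000
    # Ask inductor to micro-pipeline (overlap) the tensor-parallel collectives.
    torch._inductor.config._micro_pipeline_tp = True
    # Register the tensor-parallel process group for symmetric-memory kernels.
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)
    # Async TP only applies to compiled code, which is why the PR forces
    # training.compile = True when it is off.
    return torch.compile(model)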

train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ dataset = "c4_mini"  # supported datasets: c4_mini (45K), c4 (177M)
 
 [experimental]
 pipeline_parallel_degree = 1
+enable_async_tensor_parallel = false
 
 [checkpoint]
 enable_checkpoint = false
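To actually exercise the feature, the new flag would be flipped to true in the job's TOML. The [training] compile line below is shown for clarity, assuming the standard torchtitan training section referenced as job_config.training.compile in the hunk above; otherwise, per item 2, the job enables it automatically and logs a warning:

[training]
compile = true  # Async TP requires torch.compile

[experimental]
pipeline_parallel_degree = 1
enable_async_tensor_parallel = true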
