Bug description
When training DSV3 (DeepSeek-V3) models with FSDP2+TP+compile, I hit the following warning at process termination:
[rank0]:[W929 11:15:17.329591151 ProcessGroup.cpp:367] Warning: At the time of process termination, there are still 495 unwaited collective calls. Please review your program to ensure that:
[rank0]:1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,
[rank0]:2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective called under `with allow_inflight_collective_as_graph_input_ctx():`,
[rank0]:before the output tensors of the collective are used. (function ~WorkRegistry)
[rank0]:[W929 11:15:17.329693546 ProcessGroup.cpp:367] Warning: At the time of process termination, there are still 5 unwaited collective calls. Please review your program to ensure that:
[rank0]:1. c10d_functional.wait_tensor() is invoked on all tensors returned from c10d_functional collective,
[rank0]:2. c10d_functional.wait_tensor() is invoked on all output tensors of async_op=True torch.distributed collective called under `with allow_inflight_collective_as_graph_input_ctx():`,
[rank0]:before the output tensors of the collective are used. (function ~WorkRegistry)
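To make the warning's count concrete, below is a toy Python model of the bookkeeping it describes. Each functional collective registers a pending work item that stays outstanding until `wait_tensor()` is called on its output, and whatever is still registered when the process exits is reported as "unwaited". `ToyWorkRegistry` and its methods are hypothetical names used only for illustration. This is not PyTorch's `WorkRegistry` implementation, and it does not explain the root cause of this issue; the counts in the demo are illustrative and do not correspond to the process groups in the log above.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyWorkRegistry:
    """Hypothetical toy model of per-process-group tracking of in-flight collectives.

    Not PyTorch's implementation; it only mirrors the rule stated in the warning:
    every collective's output must be passed to wait_tensor() before process exit.
    """

    name: str
    _pending: dict = field(default_factory=dict)  # output tensor id -> op name

    def issue(self, tensor_id: int, op: str) -> int:
        # A functional collective returns immediately; its work stays
        # registered until someone waits on the output tensor.
        self._pending[tensor_id] = op
        return tensor_id

    def wait_tensor(self, tensor_id: int) -> None:
        # Analogue of c10d_functional.wait_tensor(): completes and unregisters
        # the work. Waiting twice on the same tensor is a no-op.
        self._pending.pop(tensor_id, None)

    def unwaited(self) -> int:
        return len(self._pending)

    def shutdown_warning(self) -> Optional[str]:
        # Called at "process termination": report anything never waited on.
        n = self.unwaited()
        if n == 0:
            return None
        return (
            f"[{self.name}] At the time of process termination, "
            f"there are still {n} unwaited collective calls."
        )


if __name__ == "__main__":
    # Group where every collective output is waited on: no warning.
    clean = ToyWorkRegistry("pg_clean")
    for i in range(10):
        clean.issue(i, "all_gather")
        clean.wait_tensor(i)

    # Group where some waits are skipped (illustrative only).
    leaky = ToyWorkRegistry("pg_leaky")
    for i in range(10):
        leaky.issue(i, "all_reduce")
        if i % 2 == 0:
            leaky.wait_tensor(i)

    print(clean.shutdown_warning())
    print(leaky.shutdown_warning())
```

In this model the warning is purely a count of registered-but-never-waited work, which is why it only appears at teardown rather than at the point where a wait was skipped.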
Command to reproduce:
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseek_v3 --parallelism.tensor_parallel_degree 2 --parallelism.expert_parallel_degree 1 --compile.enable --activation_checkpoint.mode "none" --training.steps 100 --training.seed 42
Behavior across a few configurations:
- DSV3+FSDP2+TP+eager: no warning
- DSV3+FSDP2+TP+compile: unwaited collective warning
- DSV3+FSDP2+EP+compile: no warning
- Llama3+FSDP2+TP+compile: no warning
Versions
See above