
Commit ba09114

pd: fix local_rank and gradient sync in multi-node training (#4811)
1. Get the local rank from the `PADDLE_LOCAL_RANK` environment variable instead of `get_rank()` (which returns the global rank).
2. Disable gradient synchronization during forward-backward and synchronize manually before the optimizer update.
3. Update the parallel-training tutorial (multi-node, multi-GPU) in the documentation.

## Summary by CodeRabbit

- **Bug Fixes**
  - Improved gradient synchronization in distributed training for multi-process setups.
  - Updated local rank assignment to use environment variables for enhanced compatibility.
- **Documentation**
  - Added an example using `mpirun` and a sample shell script to the parallel training guide for distributed training launch.

---------

Signed-off-by: HydrogenSulfate <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent 617d3e2 commit ba09114

File tree

3 files changed: +41 -10 lines changed


deepmd/pd/train/training.py

Lines changed: 25 additions & 9 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import contextlib
 import functools
 import logging
 import time
@@ -18,6 +19,7 @@
 from paddle.distributed import (
     fleet,
 )
+from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu
 from paddle.framework import (
     core,
 )
@@ -741,16 +743,30 @@ def step(_step_id, task_key="Default") -> None:
                     pref_lr = _lr.start_lr
                 else:
                     pref_lr = cur_lr
-                with nvprof_context(enable_profiling, "Forward pass"):
-                    model_pred, loss, more_loss = self.wrapper(
-                        **input_dict,
-                        cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
-                        label=label_dict,
-                        task_key=task_key,
-                    )
 
-                with nvprof_context(enable_profiling, "Backward pass"):
-                    loss.backward()
+                # disable synchronization in forward-backward manually
+                # as derivatives exist in model forward
+                no_sync_context = (
+                    self.wrapper.no_sync
+                    if self.world_size > 1
+                    else contextlib.nullcontext
+                )
+                with no_sync_context():
+                    with nvprof_context(enable_profiling, "Forward pass"):
+                        model_pred, loss, more_loss = self.wrapper(
+                            **input_dict,
+                            cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION),
+                            label=label_dict,
+                            task_key=task_key,
+                        )
+
+                    with nvprof_context(enable_profiling, "Backward pass"):
+                        loss.backward()
+
+                # fuse + allreduce manually before optimization if use DDP + no_sync
+                # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622
+                if self.world_size > 1:
+                    hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None)
 
                 if self.gradient_max_norm > 0.0:
                     with nvprof_context(enable_profiling, "Gradient clip"):
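The diff above lives inside the trainer's `self.wrapper`, so here is a minimal, self-contained sketch of the same pattern: run forward and backward under `DataParallel.no_sync()` so Paddle does not all-reduce gradients on every backward call, then fuse and all-reduce them once right before the optimizer update. The toy `Linear` model, SGD optimizer, and tensor shapes are illustrative assumptions, not DeePMD-kit code.

```python
import contextlib

import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu

world_size = dist.get_world_size()
if world_size > 1:
    # only needed when launched via `python -m paddle.distributed.launch`
    dist.init_parallel_env()

# placeholder model/optimizer standing in for DeePMD-kit's model wrapper
model = paddle.nn.Linear(16, 1)
if world_size > 1:
    model = paddle.DataParallel(model)
opt = paddle.optimizer.SGD(learning_rate=1e-3, parameters=model.parameters())

x = paddle.randn([8, 16])
y = paddle.randn([8, 1])

# skip per-backward gradient synchronization when running distributed,
# mirroring the no_sync_context logic added in training.py
no_sync_context = model.no_sync if world_size > 1 else contextlib.nullcontext
with no_sync_context():
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()

# fuse + all-reduce gradients manually, once, right before the update
if world_size > 1:
    hpu.fused_allreduce_gradients(list(model.parameters()), None)

opt.step()
opt.clear_grad()
```

The commit's comment notes that derivatives are computed inside the model forward, which is why both passes run under `no_sync` and the all-reduce happens exactly once before the update (see the linked Paddle issue for details).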

deepmd/pd/utils/env.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@
 ncpus = os.cpu_count()
 NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(0, ncpus)))
 # Make sure DDP uses correct device if applicable
-LOCAL_RANK = paddle.distributed.get_rank()
+LOCAL_RANK = int(os.environ.get("PADDLE_LOCAL_RANK", 0))
 
 if os.environ.get("DEVICE") == "cpu" or paddle.device.cuda.device_count() <= 0:
     DEVICE = "cpu"
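For the `env.py` change, a quick illustration (not DeePMD-kit code) of why the global rank from `get_rank()` is the wrong value for picking a device on multi-node jobs, assuming `paddle.distributed.launch` exports `PADDLE_LOCAL_RANK` as the commit relies on:

```python
import os

# With 2 nodes x 8 GPUs, paddle.distributed.get_rank() returns the GLOBAL rank
# (0..15); on the second node that would ask for gpu:8..gpu:15, which do not
# exist on that node. PADDLE_LOCAL_RANK is the per-node index (0..7), so it is
# the right value for device selection.
local_rank = int(os.environ.get("PADDLE_LOCAL_RANK", 0))
device = f"gpu:{local_rank}"  # e.g. global rank 9 on node 2 still maps to gpu:1
print(device)
```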

doc/train/parallel-training.md

Lines changed: 15 additions & 0 deletions
@@ -218,6 +218,21 @@ NUM_WORKERS=0 HDF5_USE_FILE_LOCKING=0 python -m paddle.distributed.launch \
     dp --pd train input.json
 ```
 
+or you can wrap the training command in a shell script and launch it with `mpirun`:
+
+```bash
+# ----- train_pp.sh -------
+unset CUDA_DEVICE_MAX_CONNECTIONS
+python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" --log_dir logs dp --pd train input_torch.json -l train_pp.log
+# -------------------------
+```
+
+Then, run the script on the first node with:
+
+```bash
+mpirun bash train_pp.sh
+```
+
 :::{note}
 
 If `NUM_WORKERS` is too large, it may cause the program to be terminated by the system;
