
Commit 2733555

zhaojuanmao authored and facebook-github-bot committed
replace all_gather with more efficient collective api _all_gather_base (pytorch#57769)
Summary:
Pull Request resolved: pytorch#57769

_all_gather_base avoids the copies that all_gather incurs, so it is more efficient.

Test Plan: unit test
Reviewed By: SciPioneer
Differential Revision: D28227193
fbshipit-source-id: ddd8590095a5b45676497a71ed792a457f9825c6
1 parent c58709b commit 2733555
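For context, a minimal sketch (not part of this commit) of where the savings come from: dist.all_gather fills a Python list of per-rank tensors that must then be stacked into one contiguous tensor, while dist._all_gather_base writes directly into a single preallocated flat buffer. Assumes an initialized process group; t is the local contribution.

    import torch
    import torch.distributed as dist

    def gather_stacked(t, world_size):
        # all_gather: one output tensor per rank, then an extra copy to stack them.
        out_list = [torch.empty_like(t) for _ in range(world_size)]
        dist.all_gather(out_list, t)
        return torch.stack(out_list, dim=0)

    def gather_flat(t, world_size):
        # _all_gather_base: the collective fills one flat buffer directly,
        # so no post-hoc stack/copy is needed.
        out = torch.empty(world_size * t.numel(), dtype=t.dtype, device=t.device)
        dist._all_gather_base(out, t)
        return out.view(world_size, *t.shape)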

File tree

3 files changed: +37 -11 lines


torch/csrc/distributed/c10d/init.cpp (+4)

@@ -1194,6 +1194,10 @@ that adds a prefix to each key inserted to the store.
           },
           py::arg("timeout") = ::c10d::kUnsetTimeout,
           py::arg("wait_all_ranks") = false,
+          py::call_guard<py::gil_scoped_release>())
+      .def(
+          "_get_backend_name",
+          &::c10d::ProcessGroup::getBackendName,
           py::call_guard<py::gil_scoped_release>());

   // base ProcessGroup::Options binding
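This binding exposes ProcessGroup::getBackendName to Python, which is what lets the SyncBatchNorm code below branch per backend. A hypothetical call site (_get_default_group is an internal accessor for the default group created by init_process_group):

    import torch.distributed as dist

    # Hypothetical usage sketch, assuming init_process_group has run.
    pg = dist.distributed_c10d._get_default_group()
    if pg._get_backend_name() == 'nccl':
        pass  # take the flat _all_gather_base path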

torch/nn/modules/_functions.py (+25 -9)

@@ -27,15 +27,31 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum,
         num_channels = input.shape[1]
         # C, C, 1 -> (2C + 1)
         combined = torch.cat([mean, invstd, count], dim=0)
-        # world_size * (2C + 1)
-        combined_list = [
-            torch.empty_like(combined) for k in range(world_size)
-        ]
-        # Use allgather instead of allreduce since I don't trust in-place operations ..
-        dist.all_gather(combined_list, combined, process_group, async_op=False)
-        combined = torch.stack(combined_list, dim=0)
-        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
-        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+        # Use allgather instead of allreduce because count could be different across
+        # ranks, and a simple allreduce op cannot give correct results.
+        # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+        # all gathered mean, invstd and count.
+        # For the NCCL backend, use the optimized version of all_gather.
+        if process_group._get_backend_name() == 'nccl':
+            # world_size * (2C + 1)
+            combined_size = combined.numel()
+            combined_flat = torch.empty(1,
+                                        combined_size * world_size,
+                                        dtype=combined.dtype,
+                                        device=combined.device)
+            dist._all_gather_base(combined_flat, combined, process_group, async_op=False)
+            combined = torch.reshape(combined_flat, (world_size, combined_size))
+            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
+        else:
+            # world_size * (2C + 1)
+            combined_list = [
+                torch.empty_like(combined) for k in range(world_size)
+            ]
+            dist.all_gather(combined_list, combined, process_group, async_op=False)
+            combined = torch.stack(combined_list, dim=0)
+            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

         # calculate global mean & invstd
         mean, invstd = torch.batch_norm_gather_stats_with_counts(
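To make the shape bookkeeping above concrete, a standalone sketch (made-up shapes, no process group) of the pack/gather/split sequence around the collective:

    import torch

    C, world_size = 4, 3           # assumed channel count and rank count
    mean = torch.randn(C)
    invstd = torch.randn(C)
    count = torch.tensor([32.0])   # per-rank count; may differ across ranks

    # C, C, 1 -> (2C + 1): one tensor handed to a single collective
    combined = torch.cat([mean, invstd, count], dim=0)

    # Stand-in for the gather: every rank contributes one (2C + 1) row
    gathered = combined.repeat(world_size, 1)    # (world_size, 2C + 1)

    # (world_size, 2C + 1) -> (world_size, C), (world_size, C), (world_size, 1)
    mean_all, invstd_all, count_all = torch.split(gathered, C, dim=1)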

torch/testing/_internal/distributed/distributed_test.py (+8 -2)

@@ -6456,7 +6456,10 @@ def test_ddp_sync_bn_training_vs_eval(self):

         # SyncBN allgathers stats across all ranks, so verify call to
         # all_gather in profiler.
-        all_gather_calls = get_profiling_event("all_gather", prof)
+        if BACKEND == 'nccl':
+            all_gather_calls = get_profiling_event("_all_gather_base", prof)
+        else:
+            all_gather_calls = get_profiling_event("all_gather", prof)
         self.assertNotEqual([], all_gather_calls)

         # Only do inference on one rank. If SyncBN did collective stats sync,

@@ -6472,7 +6475,10 @@ def test_ddp_sync_bn_training_vs_eval(self):
         loss.backward()

         # Ensure sync does not occur in eval() mode.
-        all_gather_calls = get_profiling_event("all_gather", prof)
+        if BACKEND == 'nccl':
+            all_gather_calls = get_profiling_event("_all_gather_base", prof)
+        else:
+            all_gather_calls = get_profiling_event("all_gather", prof)
         self.assertEqual([], all_gather_calls)

     @skip_if_lt_x_gpu(2)
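The test branches on BACKEND because the profiler records the collective under the name of the call actually issued. get_profiling_event is the test suite's own helper; an approximation of what such a lookup does:

    import torch

    def events_matching(name, prof):
        # Approximation of the test helper: keep profiler events whose
        # name contains the given substring.
        return [e for e in prof.function_events if name in e.name]

    # with torch.autograd.profiler.profile() as prof:
    #     model(inp)   # SyncBN forward issues the collective
    # calls = events_matching("_all_gather_base", prof)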
