
Commit d120b9a

teng-li authored and facebook-github-bot committed
Make c10d pickling/unpickling work (pytorch#12694)
Summary: This fixes the issue for pytorch#12168

Pull Request resolved: pytorch#12694
Differential Revision: D10468717
Pulled By: teng-li
fbshipit-source-id: 3df31d75eea19d6085af665f5350d3cb667a5048
1 parent 8cb0848 commit d120b9a
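
With this change, a DistributedDataParallel module constructed against the default process group can be round-tripped through torch.save/torch.load, which is what the new test below exercises. A minimal sketch of that usage, assuming a single-rank gloo setup; the init file, checkpoint path, and placeholder model are hypothetical:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Hypothetical setup: one process, default group initialized via a shared file.
dist.init_process_group("gloo", init_method="file:///tmp/ddp_init",
                        world_size=1, rank=0)

model = torch.nn.Linear(2, 4).cuda()                        # placeholder module
ddp_model = DistributedDataParallel(model, device_ids=[0])  # default process group

# __getstate__ now drops the unpicklable members (process group, CUDA streams,
# allreduce options), so saving the wrapper itself works.
torch.save(ddp_model, "/tmp/ddp_checkpoint.pt")

# __setstate__ reattaches the default process group and re-registers grad hooks.
ddp_model = torch.load("/tmp/ddp_checkpoint.pt")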

File tree

2 files changed: +59 -30 lines changed


test/test_distributed.py

+28 -19
@@ -28,10 +28,29 @@
 DEFAULT_TIMEOUT = 300
 CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}
 
+
 if INIT_METHOD.startswith("file://"):
     FOLDER = INIT_METHOD[7:]
 
 
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.fc1 = nn.Linear(2, 10, bias=False)
+        self.fc2 = nn.Linear(10, 50, bias=False)
+        self.fc3 = nn.Linear(50, 4, bias=False)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
+        x = self.fc3(x)
+        return F.softmax(x, dim=1)
+
+
+DDP_NET = Net()
+
+
 def get_timeout(test_id):
     test_name = test_id.split(".")[-1]
     if test_name in CUSTOMIZED_TIMEOUT:
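
The test model is hoisted from the old _create_Net helper to module level because pickle serializes an instance by reference to an importable class; a class defined inside a function or method cannot be located when loading, which would otherwise break the new save/load round trip. A small standalone illustration of that constraint (names here are only for demonstration):

import pickle
import torch.nn as nn


class ModuleLevelNet(nn.Module):
    # Importable by qualified name, so instances can be pickled.
    def __init__(self):
        super(ModuleLevelNet, self).__init__()
        self.fc = nn.Linear(2, 2)


def make_local_net():
    # Defined inside a function: pickle cannot look this class up by name.
    class LocalNet(nn.Module):
        def __init__(self):
            super(LocalNet, self).__init__()
            self.fc = nn.Linear(2, 2)
    return LocalNet()


pickle.dumps(ModuleLevelNet())         # works
try:
    pickle.dumps(make_local_net())     # fails: class is not importable
except (pickle.PicklingError, AttributeError) as err:
    print("cannot pickle locally defined class:", err)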
@@ -44,6 +63,7 @@ def get_timeout(test_id):
     print("Distributed not available, skipping tests")
     sys.exit(0)
 
+
 SKIP_IF_NO_CUDA_EXIT_CODE = 75
 SKIP_IF_NO_GPU_EXIT_CODE = 76
 SKIP_IF_SMALL_WORLDSIZE_EXIT_CODE = 77
@@ -1109,23 +1129,6 @@ def test_all_gather_multigpu(self):
         rank_to_GPU = self._init_multigpu_helper()
         self._test_all_gather_multigpu_helper(group, group_id, rank, rank_to_GPU)
 
-    def _create_Net(self):
-        class Net(nn.Module):
-            def __init__(self):
-                super(Net, self).__init__()
-                self.fc1 = nn.Linear(2, 10, bias=False)
-                self.fc2 = nn.Linear(10, 50, bias=False)
-                self.fc3 = nn.Linear(50, 4, bias=False)
-                self.relu = nn.ReLU()
-
-            def forward(self, x):
-                x = self.relu(self.fc1(x))
-                x = self.relu(self.fc2(x))
-                x = self.fc3(x)
-                return F.softmax(x, dim=1)
-
-        return Net()
-
     def _model_step(self, model):
         for param in model.parameters():
             param.data += param.grad
@@ -1193,7 +1196,7 @@ def _test_DistributedDataParallel(self, gpu_subset, rank, output_device=None):
         # as baseline
 
         # cpu training setup
-        model = self._create_Net()
+        model = DDP_NET
 
         # single gpu training setup
         model_gpu = copy.deepcopy(model)
@@ -1206,6 +1209,12 @@ def _test_DistributedDataParallel(self, gpu_subset, rank, output_device=None):
             model_DDP, device_ids=gpu_subset
         )
 
+        # test serializable/unserializable
+        if INIT_METHOD.startswith("file://"):
+            _, filename = tempfile.mkstemp(prefix=FOLDER)
+            torch.save(model_DDP, filename)
+            model_DDP = torch.load(filename)
+
         # dummy data initialization
         local_bs = len(gpu_subset)
         global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)
@@ -1232,7 +1241,7 @@ def test_DistributedDataParallelCPU(self):
         group, group_id, rank = self._init_global_test()
 
         # cpu training setup
-        model_base = self._create_Net()
+        model_base = DDP_NET
 
         # DDP-CPU training setup
         model_DDP = copy.deepcopy(model_base)

torch/nn/parallel/distributed.py

+31 -11
@@ -138,8 +138,6 @@ def __init__(self, module, device_ids=None,
         self.output_device = _get_device_index(output_device, True)
         self.broadcast_buffers = broadcast_buffers
 
-        self.allreduce_opts = dist.AllreduceOptions()
-
         MB = 1024 * 1024
 
         # used for intra-node param sync and inter-node sync as well
@@ -207,26 +205,39 @@ def __init__(self, module, device_ids=None,
         self.next_bucket = len(self.bucket_sizes) - 1
         self.ready_buckets_not_reduced = set()
         self.reduction_works = [None for _ in range(len(self.bucket_sizes))]
-
         self.devs_ready = [0 for _ in range(len(self.bucket_sizes))]
-
-        # default stream tracking to launch nccl reduce kernels
-        self.default_streams = []
-        for dev_id in self.device_ids:
-            with torch.cuda.device(dev_id):
-                self.default_streams.append(torch.cuda.current_stream())
-
         self._register_grad_hooks()
 
     def __getstate__(self):
+        self._check_default_group()
         attrs = copy.copy(self.__dict__)
-        del attrs['_grad_accs']
+        del attrs['process_group'], \
+            attrs['allreduce_opts'], \
+            attrs['default_streams'], \
+            attrs['_grad_accs']
         return attrs
 
     def __setstate__(self, state):
+        # If serializable, then the process group should be the default one
+        self.process_group = dist.get_default_group()
         super(DistributedDataParallel, self).__setstate__(state)
         self._register_grad_hooks()
 
+    def _check_default_group(self):
+        pickle_not_supported = False
+        try:
+            if self.process_group != dist.get_default_group():
+                pickle_not_supported = True
+        except RuntimeError:
+            pickle_not_supported = True
+
+        if pickle_not_supported:
+            raise RuntimeError("DDP Pickling/Unpickling are only supported "
+                               "when using DDP with the default process "
+                               "group. That is, when you have called "
+                               "init_process_group and have not passed "
+                               "process_group argument to DDP constructor")
+
     def forward(self, *inputs, **kwargs):
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         self._sync_params()
@@ -279,6 +290,15 @@ def _sync_params(self):
 
     def _register_grad_hooks(self):
         self._grad_accs = []  # need to keep them in scope
+
+        # default stream tracking to launch nccl reduce kernels
+        self.default_streams = []
+        for dev_id in self.device_ids:
+            with torch.cuda.device(dev_id):
+                self.default_streams.append(torch.cuda.current_stream())
+
+        self.allreduce_opts = dist.AllreduceOptions()
+
         for device_idx, module in enumerate(self._module_copies):
             for p in module.parameters():
                 if p.requires_grad:
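
The _check_default_group guard added above means the round trip is only supported when DDP was constructed without an explicit process_group; wrapping a model with a group created via dist.new_group is expected to fail at pickling time. A hedged sketch of that failure mode (backend, ranks, and device choices are illustrative; assumes the default group is already initialized across two ranks):

import pickle
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Illustrative non-default group built on top of the default one.
subgroup = dist.new_group(ranks=[0, 1])

ddp_model = DistributedDataParallel(torch.nn.Linear(2, 2).cuda(),
                                    device_ids=[0],
                                    process_group=subgroup)

try:
    pickle.dumps(ddp_model)        # __getstate__ calls _check_default_group
except RuntimeError as err:
    # "DDP Pickling/Unpickling are only supported when using DDP with the
    # default process group ..."
    print(err)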
