Skip to content

Commit a695fcf

Browse files
yhcharles authored and pytorchmergebot committed
Add tests for replicate multiple modules (pytorch#89099)
Pull Request resolved: pytorch#89099
Approved by: https://github.com/zhaojuanmao
1 parent 767f6aa commit a695fcf

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

Diff for: test/distributed/_composable/test_replicate.py

+21 -14
Original file line number | Diff line number | Diff line change
@@ -39,13 +39,7 @@ def tearDown(self):
3939
except OSError:
4040
pass
4141

42-
def _prepare_module(self, global_batch_size):
43-
model = Net()
44-
input = torch.randn(global_batch_size, 2)
45-
target = torch.randn(global_batch_size, 4)
46-
return model, input, target
47-
48-
def test_replicate(self):
42+
def _compare_module(self, mod, replicate_mod):
4943
dist.init_process_group(
5044
backend="gloo",
5145
rank=self.rank,
@@ -55,8 +49,8 @@ def test_replicate(self):
5549

5650
local_batch_size = 1
5751
global_batch_size = self.world_size * local_batch_size
58-
model, input, target = self._prepare_module(global_batch_size)
59-
replicate_model = mark_root_module(replicate(deepcopy(model)))
52+
input = torch.randn(global_batch_size, 2)
53+
target = torch.randn(global_batch_size, 4)
6054

6155
def step_model(model, input, target):
6256
model.train()
@@ -69,9 +63,9 @@ def step_model(model, input, target):
6963
param.grad = None
7064

7165
for iteration in range(2):
72-
step_model(model, input, target)
66+
step_model(mod, input, target)
7367
step_model(
74-
replicate_model,
68+
replicate_mod,
7569
input[
7670
self.rank
7771
* local_batch_size : (self.rank + 1)
@@ -85,16 +79,29 @@ def step_model(model, input, target):
8579
)
8680

8781
self.assertEqual(
88-
len(list(model.parameters())),
89-
len(list(replicate_model.parameters())),
82+
len(list(mod.parameters())),
83+
len(list(replicate_mod.parameters())),
9084
)
91-
for i, j in zip(model.parameters(), replicate_model.parameters()):
85+
for i, j in zip(mod.parameters(), replicate_mod.parameters()):
9286
self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
9387

9488
# Shuffle the input so that DDP input is different
9589
torch.manual_seed(iteration)
9690
input = input[torch.randperm(global_batch_size)]
9791

92+
def test_replicate_single_module(self):
93+
model = Net()
94+
replicate_model = mark_root_module(replicate(deepcopy(model)))
95+
self._compare_module(model, replicate_model)
96+
97+
def test_replicate_multi_module(self):
98+
model = Net()
99+
replicate_model = mark_root_module(deepcopy(model))
100+
replicate(replicate_model.fc1)
101+
replicate(replicate_model.fc2)
102+
replicate(replicate_model.fc3)
103+
self._compare_module(model, replicate_model)
104+
98105

99106
if __name__ == "__main__":
100107
run_tests()

0 commit comments

Comments
 (0)