@@ -39,13 +39,7 @@ def tearDown(self):
39
39
except OSError :
40
40
pass
41
41
42
- def _prepare_module (self , global_batch_size ):
43
- model = Net ()
44
- input = torch .randn (global_batch_size , 2 )
45
- target = torch .randn (global_batch_size , 4 )
46
- return model , input , target
47
-
48
- def test_replicate (self ):
42
+ def _compare_module (self , mod , replicate_mod ):
49
43
dist .init_process_group (
50
44
backend = "gloo" ,
51
45
rank = self .rank ,
@@ -55,8 +49,8 @@ def test_replicate(self):
55
49
56
50
local_batch_size = 1
57
51
global_batch_size = self .world_size * local_batch_size
58
- model , input , target = self . _prepare_module (global_batch_size )
59
- replicate_model = mark_root_module ( replicate ( deepcopy ( model )) )
52
+ input = torch . randn (global_batch_size , 2 )
53
+ target = torch . randn ( global_batch_size , 4 )
60
54
61
55
def step_model (model , input , target ):
62
56
model .train ()
@@ -69,9 +63,9 @@ def step_model(model, input, target):
69
63
param .grad = None
70
64
71
65
for iteration in range (2 ):
72
- step_model (model , input , target )
66
+ step_model (mod , input , target )
73
67
step_model (
74
- replicate_model ,
68
+ replicate_mod ,
75
69
input [
76
70
self .rank
77
71
* local_batch_size : (self .rank + 1 )
@@ -85,16 +79,29 @@ def step_model(model, input, target):
85
79
)
86
80
87
81
self .assertEqual (
88
- len (list (model .parameters ())),
89
- len (list (replicate_model .parameters ())),
82
+ len (list (mod .parameters ())),
83
+ len (list (replicate_mod .parameters ())),
90
84
)
91
- for i , j in zip (model .parameters (), replicate_model .parameters ()):
85
+ for i , j in zip (mod .parameters (), replicate_mod .parameters ()):
92
86
self .assertEqual (i , j , rtol = 1.3e-06 , atol = 5e-5 )
93
87
94
88
# Shuffle the input so that DDP input is different
95
89
torch .manual_seed (iteration )
96
90
input = input [torch .randperm (global_batch_size )]
97
91
92
+ def test_replicate_single_module (self ):
93
+ model = Net ()
94
+ replicate_model = mark_root_module (replicate (deepcopy (model )))
95
+ self ._compare_module (model , replicate_model )
96
+
97
+ def test_replicate_multi_module (self ):
98
+ model = Net ()
99
+ replicate_model = mark_root_module (deepcopy (model ))
100
+ replicate (replicate_model .fc1 )
101
+ replicate (replicate_model .fc2 )
102
+ replicate (replicate_model .fc3 )
103
+ self ._compare_module (model , replicate_model )
104
+
98
105
99
106
if __name__ == "__main__" :
100
107
run_tests ()
0 commit comments