9
9
from unittest .mock import create_autospec
10
10
11
11
import torch
12
+ from parameterized import parameterized
12
13
from torch import nn , optim
13
14
14
15
from torchft .local_sgd import DiLoCo , LocalSGD
@@ -145,42 +146,33 @@ def test_diloco_healthy(self) -> None:
145
146
outer_opt_state = outer_optimizer .state_dict ()
146
147
self .assertEqual (len (outer_opt_state ["state" ]), parameter_count )
147
148
148
- def test_diloco_without_bucketization (self ):
149
+ @parameterized .expand (
150
+ [
151
+ (
152
+ "without_bucketization" ,
153
+ False ,
154
+ lambda self , manager , model : self .assertEqual (
155
+ manager .allreduce .call_count , len (list (model .parameters ()))
156
+ ),
157
+ ),
158
+ (
159
+ "with_bucketization" ,
160
+ True ,
161
+ lambda self , manager , model : self .assertGreaterEqual (
162
+ manager .allreduce .call_count , 1
163
+ ),
164
+ ),
165
+ ]
166
+ )
167
+ def test_diloco_all_reduce (self , name , use_bucketization , assert_func ):
149
168
model = SimpleModel ()
150
169
inner_optimizer = optim .AdamW (
151
170
model .parameters (), lr = 4e-4 , weight_decay = 0.1 , betas = (0.9 , 0.95 )
152
171
)
153
172
outer_optimizer = optim .SGD (
154
173
model .parameters (), lr = 0.7 , momentum = 0.9 , nesterov = True
155
174
)
156
- manager = create_autospec (Manager )
157
- manager ._use_async_quorum = False
158
-
159
- with DiLoCo (
160
- manager ,
161
- model ,
162
- inner_optimizer ,
163
- outer_optimizer ,
164
- sync_every = 2 ,
165
- use_bucketization = False ,
166
- ) as diloco :
167
- inp = torch .rand (2 , 3 )
168
- loss = model (inp ).mean ()
169
- loss .backward ()
170
- inner_optimizer .step ()
171
- self .assertEqual (diloco ._local_step , 1 )
172
- self .assertEqual (
173
- manager .allreduce .call_count , len (list (model .parameters ()))
174
- )
175
175
176
- def test_diloco_with_bucketization (self ):
177
- model = SimpleModel ()
178
- inner_optimizer = optim .AdamW (
179
- model .parameters (), lr = 4e-4 , weight_decay = 0.1 , betas = (0.9 , 0.95 )
180
- )
181
- outer_optimizer = optim .SGD (
182
- model .parameters (), lr = 0.7 , momentum = 0.9 , nesterov = True
183
- )
184
176
manager = create_autospec (Manager )
185
177
manager ._use_async_quorum = False
186
178
@@ -190,11 +182,12 @@ def test_diloco_with_bucketization(self):
190
182
inner_optimizer ,
191
183
outer_optimizer ,
192
184
sync_every = 2 ,
193
- use_bucketization = True ,
185
+ use_bucketization = use_bucketization ,
194
186
) as diloco :
195
187
inp = torch .rand (2 , 3 )
196
188
loss = model (inp ).mean ()
197
189
loss .backward ()
198
190
inner_optimizer .step ()
191
+
199
192
self .assertEqual (diloco ._local_step , 1 )
200
- self . assertGreaterEqual ( manager . allreduce . call_count , 1 )
193
+ assert_func ( self , manager , model )
0 commit comments