 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
 from torch.utils.hooks import RemovableHandle
+import torch.distributed as dist
 
 from torchft.manager import Manager
 
@@ -223,17 +224,35 @@ def _perform_sync(self) -> None:
         self._outer_optimizer.step()
         self._save_parameters()
         self._outer_optimizer.zero_grad()
-
+
     def _average_grads(self) -> None:
         """
-        Average the gradients across the diloco group.
+        Efficiently averages gradients across the diloco group using buffer-based bucketization.
         """
-        works = []
-        for p in self._model.parameters():
-            # Perform allreduce on the pseudogradients
-            assert p.grad is not None
-            work = self._manager.allreduce(p.grad)
-            works.append(work)
-        # Wait for all allreduce operations to complete
-        for work in works:
-            work.wait()
+
+        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
+
+        if not grads:
+            return  # No gradients to process
+
+        # Compute total size and allocate a flat buffer for all gradients
+        total_size = sum(g.numel() for g in grads)
+        flat_buffer = torch.zeros(total_size, dtype=grads[0].dtype, device=grads[0].device)
+
+        # Pack gradients into the buffer
+        offset = 0
+        for g in grads:
+            flat_buffer[offset : offset + g.numel()].copy_(g.view(-1))
+            offset += g.numel()
+
+        # Perform allreduce on the entire buffer
+        work = self._manager.allreduce(flat_buffer)
+
+        # Wait for the allreduce to complete
+        work.wait()
+
+        # Unpack gradients back into their original tensors
+        offset = 0
+        for g in grads:
+            g.copy_(flat_buffer[offset : offset + g.numel()].view_as(g))
+            offset += g.numel()
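
For context, here is a minimal, self-contained sketch of the pack → allreduce → unpack pattern the new `_average_grads` follows. The helpers `pack`, `unpack`, and `fake_allreduce` are hypothetical names used only for illustration; the actual change hands a single flat buffer to `self._manager.allreduce()` and waits on the returned work handle rather than averaging in-process.

```python
import torch

def pack(grads: list[torch.Tensor]) -> torch.Tensor:
    # Flatten all gradients into one contiguous staging buffer.
    total = sum(g.numel() for g in grads)
    flat = torch.zeros(total, dtype=grads[0].dtype, device=grads[0].device)
    offset = 0
    for g in grads:
        flat[offset : offset + g.numel()].copy_(g.view(-1))
        offset += g.numel()
    return flat

def unpack(flat: torch.Tensor, grads: list[torch.Tensor]) -> None:
    # Copy the averaged values back into the original gradient tensors.
    offset = 0
    for g in grads:
        g.copy_(flat[offset : offset + g.numel()].view_as(g))
        offset += g.numel()

def fake_allreduce(buffers: list[torch.Tensor]) -> None:
    # Hypothetical stand-in for the manager's allreduce: averages the flat
    # buffers of in-process "replicas" the way an averaging allreduce would.
    avg = torch.stack(buffers).mean(dim=0)
    for b in buffers:
        b.copy_(avg)

# Two simulated replicas, each holding two gradient tensors.
replica_a = [torch.ones(3), torch.full((2, 2), 2.0)]
replica_b = [torch.full((3,), 3.0), torch.zeros(2, 2)]

flats = [pack(replica_a), pack(replica_b)]
fake_allreduce(flats)
unpack(flats[0], replica_a)
unpack(flats[1], replica_b)

print(replica_a[0])  # tensor([2., 2., 2.])
print(replica_b[1])  # tensor([[1., 1.], [1., 1.]])
```

The trade-off shown here is the same one the diff makes: a single allreduce over one flat buffer avoids issuing a separate collective per parameter, at the cost of an extra copy into and out of the staging buffer.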