13
13
from typing import Any , Callable , Dict , Iterator , List , Optional , Tuple , Type
14
14
15
15
import torch
16
+ import torch .distributed as dist
16
17
from torch import nn , optim
17
18
from torch .distributed .tensor import DTensor
18
19
from torch .nn .parameter import Parameter
@@ -166,6 +167,9 @@ class DiLoCo:
166
167
DiLoCo paper: https://arxiv.org/pdf/2311.08105
167
168
"""
168
169
170
+ bucket_cap_mb : int = 32 * 1024 * 1024
171
+ use_bucketization : bool = False
172
+
169
173
def __init__ (
170
174
self ,
171
175
manager : Manager ,
@@ -175,6 +179,8 @@ def __init__(
175
179
sync_every : int ,
176
180
backup_device : Optional [torch .device ] = None ,
177
181
pin_memory : bool = True ,
182
+ use_bucketization : bool = False ,
183
+ bucket_cap_mb : Optional [int ] = None ,
178
184
) -> None :
179
185
"""
180
186
Args:
@@ -204,6 +210,12 @@ def __init__(
204
210
205
211
self ._hooks : List [RemovableHandle ] = []
206
212
self ._outer_optimizer = outer_optimizer
213
+
214
+ if bucket_cap_mb is not None :
215
+ self .bucket_cap_mb = int (bucket_cap_mb * 1024 * 1024 )
216
+
217
+ self .use_bucketization = use_bucketization
218
+
207
219
self .original_parameters : Dict [str , torch .Tensor ] = {}
208
220
for name , p in self ._model .named_parameters ():
209
221
if isinstance (p , DTensor ):
@@ -308,8 +320,17 @@ def _perform_sync(self) -> None:
308
320
309
321
def _average_grads(self) -> None:
    """
    Run the gradient allreduce using the strategy selected on this instance.

    When ``use_bucketization`` is truthy the gradients are handed to
    ``_allreduce_bucketized``; otherwise ``_allreduce_per_param`` is used.
    """
    # Guard clause: the per-parameter path is the fallback.
    if not self.use_bucketization:
        self._allreduce_per_param()
        return

    self._allreduce_bucketized()
331
+
332
+ def _allreduce_per_param (self ) -> None :
333
+ """Performs allreduce on each gradient tensor separately (original method)."""
313
334
works = []
314
335
for p in self ._model .parameters ():
315
336
# Perform allreduce on the pseudogradients
@@ -319,6 +340,60 @@ def _average_grads(self) -> None:
319
340
else :
320
341
work = self ._manager .allreduce (p .grad )
321
342
works .append (work )
322
- # Wait for all allreduce operations to complete
343
+
323
344
for work in works :
324
345
work .wait ()
346
+
347
def bucketize_and_allreduce(
    self,
    tensors: List[torch.Tensor],
    bucket_size_bytes: int,
) -> None:
    """
    Applies allreduce on a list of tensors using bucketization.

    Tensors are packed, in order, into flat buffers. Each buffer is passed
    to ``self._manager.allreduce``, waited on, and the reduced values are
    copied back into the original tensors in place.

    Every tensor is guaranteed to land in exactly one bucket: a bucket is
    closed as soon as the next tensor would overflow it, and a tensor that
    is larger than the cap on its own is sent in a dedicated (oversized)
    bucket instead of being skipped.

    Args:
        tensors: List of torch tensors (e.g., gradients). They are modified
            in place. The dtype and device of the first tensor are used for
            every flat buffer.
        bucket_size_bytes: Max size of each bucket in bytes. Values smaller
            than one element are treated as "one tensor per bucket".
    """
    if not tensors:
        return

    dtype, device = tensors[0].dtype, tensors[0].device
    # Clamp to at least one element so a tiny or zero cap cannot stall the loop.
    bucket_numel = max(1, bucket_size_bytes // tensors[0].element_size())

    total = len(tensors)
    start = 0
    while start < total:
        # Greedily extend the bucket; always accept the first tensor so an
        # oversized tensor still makes progress.
        end = start
        packed = 0
        while end < total:
            numel = tensors[end].numel()
            if end > start and packed + numel > bucket_numel:
                break
            packed += numel
            end += 1

        bucket = tensors[start:end]
        start = end

        if packed == 0:
            # Only zero-element tensors in this bucket: nothing to reduce.
            continue

        # Buffer is sized to exactly what was packed, so no padding is sent.
        flat_buffer = torch.zeros(packed, dtype=dtype, device=device)

        cursor = 0
        for t in bucket:
            numel = t.numel()
            # reshape (not view) so non-contiguous tensors are accepted.
            flat_buffer[cursor : cursor + numel].copy_(t.reshape(-1))
            cursor += numel

        work = self._manager.allreduce(flat_buffer)
        work.wait()

        cursor = 0
        for t in bucket:
            numel = t.numel()
            t.copy_(flat_buffer[cursor : cursor + numel].view_as(t))
            cursor += numel
390
+
391
def _allreduce_bucketized(self) -> None:
    """
    Gather every available gradient from the model and allreduce them in buckets.

    Parameters whose ``grad`` is ``None`` are left out. The remaining
    gradients are forwarded to ``bucketize_and_allreduce`` with
    ``self.bucket_cap_mb`` as the bucket size.
    """
    # NOTE(review): gradients are forwarded as-is; confirm whether DTensor
    # gradients need any unwrapping before being packed into a flat buffer.
    pending = []
    for param in self._model.parameters():
        grad = param.grad
        if grad is not None:
            pending.append(grad)

    self.bucketize_and_allreduce(pending, bucket_size_bytes=self.bucket_cap_mb)
0 commit comments