13
13
from typing import Any , Callable , Dict , Iterator , List , Mapping , Optional , Type
14
14
15
15
import torch
16
+ import torch .distributed as dist
16
17
from torch import nn , optim
17
18
from torch .nn .parameter import Parameter
18
19
from torch .optim .optimizer import Optimizer
19
20
from torch .utils .hooks import RemovableHandle
20
- import torch .distributed as dist
21
21
22
22
from torchft .manager import Manager
23
23
@@ -183,6 +183,8 @@ class DiLoCo(LocalSGD):
183
183
diloco: https://arxiv.org/pdf/2311.08105
184
184
"""
185
185
186
+ BUCKET_SIZE_BYTES = 32 * 1024 * 1024
187
+
186
188
def __init__ (
187
189
self ,
188
190
manager : Manager ,
@@ -192,6 +194,7 @@ def __init__(
192
194
sync_every : int ,
193
195
backup_device : Optional [torch .device ] = None ,
194
196
pin_memory : bool = True ,
197
+ use_bucketization = False ,
195
198
) -> None :
196
199
if manager ._use_async_quorum :
197
200
raise ValueError (
@@ -224,35 +227,67 @@ def _perform_sync(self) -> None:
224
227
self ._outer_optimizer .step ()
225
228
self ._save_parameters ()
226
229
self ._outer_optimizer .zero_grad ()
227
-
230
+
228
231
def _average_grads(self) -> None:
    """
    Average gradients across the replica group.

    Dispatches to one of two equivalent strategies:
    - bucketized allreduce (fewer, larger collective calls), or
    - per-parameter allreduce (one collective per gradient tensor).
    """
    # Pick the reduction strategy once, then invoke it.
    reduce_fn = (
        self._allreduce_bucketized
        if self.use_bucketization
        else self._allreduce_per_param
    )
    reduce_fn()
def _allreduce_per_param(self) -> None:
    """Allreduce each gradient tensor individually (original behavior).

    All collectives are issued first and awaited afterwards, so the
    underlying communication can overlap across parameters.
    """
    # Issue one async allreduce per gradient-bearing parameter.
    pending = [
        self._manager.allreduce(param.grad)
        for param in self._model.parameters()
        if param.grad is not None
    ]
    # Block until every outstanding collective has completed.
    for handle in pending:
        handle.wait()
def _allreduce_bucketized(self) -> None:
    """
    Average gradients using bucketized allreduce.

    Consecutive gradient tensors are packed into a flat buffer of at most
    ``BUCKET_SIZE_BYTES``, allreduced with a single collective call, and the
    averaged values are unpacked back into the original grad tensors. This
    reduces the number of collectives versus one allreduce per parameter.

    NOTE(review): all gradients are assumed to share the dtype and device of
    the first gradient — confirm this holds for mixed-precision models.
    """
    grads = [p.grad for p in self._model.parameters() if p.grad is not None]
    if not grads:
        return  # nothing to reduce

    dtype, device = grads[0].dtype, grads[0].device
    # Bucket capacity in elements (not bytes); clamp to >= 1 so the pack
    # loop below always makes progress.
    bucket_cap = max(1, self.BUCKET_SIZE_BYTES // grads[0].element_size())

    # Walk the gradient list ONCE with a persistent index. The previous
    # implementation restarted packing from grads[0] on every chunk, so the
    # first gradients were reduced repeatedly while later ones were never
    # reduced at all.
    i, n = 0, len(grads)
    while i < n:
        # Greedily pack consecutive grads into this bucket. A gradient
        # larger than bucket_cap still gets a bucket of its own (the size
        # check is skipped for the first tensor) so it is never dropped.
        bucket: List = []
        pack_offset = 0
        while i < n:
            numel = grads[i].numel()
            if bucket and pack_offset + numel > bucket_cap:
                break
            bucket.append((grads[i], pack_offset, numel))
            pack_offset += numel
            i += 1  # advance so each grad is reduced exactly once

        # Allocate exactly the number of packed elements — no padding zeros
        # are sent over the wire.
        flat_buffer = torch.zeros(pack_offset, dtype=dtype, device=device)
        for g, off, numel in bucket:
            flat_buffer[off : off + numel].copy_(g.view(-1))

        # One collective for the whole bucket.
        work = self._manager.allreduce(flat_buffer)
        work.wait()

        # Unpack the averaged values back into the original tensors.
        for g, off, numel in bucket:
            g.copy_(flat_buffer[off : off + numel].view_as(g))