Skip to content

Commit 910eb3f

Browse files
committed
Implementing bucketized model averaging for LocalSGD
1 parent 8628a3f commit 910eb3f

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

torchft/local_sgd.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torch.nn.parameter import Parameter
1818
from torch.optim.optimizer import Optimizer
1919
from torch.utils.hooks import RemovableHandle
20+
import torch.distributed as dist
2021

2122
from torchft.manager import Manager
2223

@@ -223,17 +224,35 @@ def _perform_sync(self) -> None:
223224
self._outer_optimizer.step()
224225
self._save_parameters()
225226
self._outer_optimizer.zero_grad()
226-
227+
227228
def _average_grads(self) -> None:
228229
"""
229-
Average the gradients across the diloco group.
230+
Efficiently averages gradients across the diloco group using buffer-based bucketization.
230231
"""
231-
works = []
232-
for p in self._model.parameters():
233-
# Perform allreduce on the pseudogradients
234-
assert p.grad is not None
235-
work = self._manager.allreduce(p.grad)
236-
works.append(work)
237-
# Wait for all allreduce operations to complete
238-
for work in works:
239-
work.wait()
232+
233+
grads = [p.grad for p in self._model.parameters() if p.grad is not None]
234+
235+
if not grads:
236+
return # No gradients to process
237+
238+
# Compute total size and allocate a flat buffer for all gradients
239+
total_size = sum(g.numel() for g in grads)
240+
flat_buffer = torch.zeros(total_size, dtype=grads[0].dtype, device=grads[0].device)
241+
242+
# Pack gradients into the buffer
243+
offset = 0
244+
for g in grads:
245+
flat_buffer[offset : offset + g.numel()].copy_(g.view(-1))
246+
offset += g.numel()
247+
248+
# Perform Allreduce on the entire buffer
249+
work = self._manager.allreduce(flat_buffer)
250+
251+
# Wait for Allreduce to complete
252+
work.wait()
253+
254+
# Unpack gradients back into their original tensors
255+
offset = 0
256+
for g in grads:
257+
g.copy_(flat_buffer[offset : offset + g.numel()].view_as(g))
258+
offset += g.numel()

0 commit comments

Comments
 (0)