Commit 0a5bc89

Implementing bucketized model averaging for LocalSGD (#111)
* Implementing bucketized model averaging for LocalSGD
* Adding a flag and unit tests. There were some issues while running the unit test cases; will look into them more.
* Updates based on comments:
  1) Fixed lint issues
  2) Changed variable name for bucket size
  3) Added parameterized unit test
* Update local_sgd.py
* Update local_sgd.py
* New updates:
  1) Created new test to check correctness of bucketization logic.
  2) Separated the bucketizing algorithm into a dedicated function
* Fixing Pyre issues
1 parent 0691f80 commit 0a5bc89
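
For context, here is a minimal usage sketch of the flags this commit adds (not part of the commit itself). It mirrors the unit tests below by mocking the torchft Manager; the Manager import path, the model, and the optimizer settings are illustrative assumptions.

# Hedged usage sketch: enabling the new bucketized averaging path on DiLoCo.
# The Manager is mocked here, as in the unit tests; a real run would use a
# live torchft Manager from the training setup.
from unittest.mock import create_autospec

import torch
from torch import nn

from torchft.local_sgd import DiLoCo
from torchft.manager import Manager  # assumed import path for Manager

model = nn.Linear(3, 4)
inner_opt = torch.optim.AdamW(model.parameters(), lr=4e-4)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.7)

manager = create_autospec(Manager)
manager._use_async_quorum = False
manager.should_commit.return_value = True

with DiLoCo(
    manager,
    model,
    inner_opt,
    outer_opt,
    sync_every=2,
    use_bucketization=True,  # new flag: bucketized allreduce of pseudogradients
    bucket_cap_mb=32,        # new flag: bucket cap, converted to bytes internally
) as diloco:
    for _ in range(2):
        loss = model(torch.rand(2, 3)).sum()
        loss.backward()
        inner_opt.step()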

File tree

2 files changed (+172, -2 lines changed)


torchft/local_sgd.py (+77, -2)
@@ -13,6 +13,7 @@
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
 
 import torch
+import torch.distributed as dist
 from torch import nn, optim
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
@@ -166,6 +167,9 @@ class DiLoCo:
     DiLoCo paper: https://arxiv.org/pdf/2311.08105
     """
 
+    bucket_cap_mb: int = 32 * 1024 * 1024
+    use_bucketization: bool = False
+
     def __init__(
         self,
         manager: Manager,
@@ -175,6 +179,8 @@ def __init__(
         sync_every: int,
         backup_device: Optional[torch.device] = None,
         pin_memory: bool = True,
+        use_bucketization: bool = False,
+        bucket_cap_mb: Optional[int] = None,
     ) -> None:
         """
         Args:
@@ -204,6 +210,12 @@ def __init__(
 
         self._hooks: List[RemovableHandle] = []
         self._outer_optimizer = outer_optimizer
+
+        if bucket_cap_mb is not None:
+            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
+
+        self.use_bucketization = use_bucketization
+
         self.original_parameters: Dict[str, torch.Tensor] = {}
         for name, p in self._model.named_parameters():
             if isinstance(p, DTensor):
@@ -308,8 +320,17 @@ def _perform_sync(self) -> None:
 
     def _average_grads(self) -> None:
         """
-        Average the gradients across the diloco group.
+        Efficiently averages gradients across the group using either:
+        - Per-parameter allreduce (old behavior)
+        - Bucketized allreduce (new behavior)
         """
+        if self.use_bucketization:
+            self._allreduce_bucketized()
+        else:
+            self._allreduce_per_param()
+
+    def _allreduce_per_param(self) -> None:
+        """Performs allreduce on each gradient tensor separately (original method)."""
         works = []
         for p in self._model.parameters():
             # Perform allreduce on the pseudogradients
@@ -319,6 +340,60 @@ def _average_grads(self) -> None:
             else:
                 work = self._manager.allreduce(p.grad)
             works.append(work)
-        # Wait for all allreduce operations to complete
+
         for work in works:
             work.wait()
+
+    def bucketize_and_allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        bucket_size_bytes: int,
+    ) -> None:
+        """
+        Applies allreduce on a list of tensors using bucketization.
+
+        Args:
+            tensors: List of torch tensors (e.g., gradients).
+            bucket_size_bytes: Max size of each bucket in bytes.
+        """
+        if not tensors:
+            return
+
+        total_size = sum(t.numel() for t in tensors)
+        dtype, device = tensors[0].dtype, tensors[0].device
+
+        offset = 0
+        flat_index = 0
+        while offset < total_size:
+            chunk_size = min(
+                bucket_size_bytes // tensors[0].element_size(), total_size - offset
+            )
+            flat_buffer = torch.zeros(chunk_size, dtype=dtype, device=device)
+
+            pack_offset, bucket_tensors = 0, []
+            for t in tensors[flat_index:]:
+                numel = t.numel()
+                if pack_offset + numel > chunk_size:
+                    break
+                flat_buffer[pack_offset : pack_offset + numel].copy_(t.view(-1))
+                bucket_tensors.append((t, pack_offset, numel))
+                pack_offset += numel
+                flat_index += 1
+
+            work = self._manager.allreduce(flat_buffer)
+            work.wait()
+
+            for t, pack_offset, numel in bucket_tensors:
+                t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
+
+            offset += chunk_size
+
+    def _allreduce_bucketized(self) -> None:
+        """
+        Averages gradients using bucketized allreduce with a fixed buffer.
+        """
+        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
+        self.bucketize_and_allreduce(
+            grads,
+            bucket_size_bytes=self.bucket_cap_mb,
+        )
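
Note that despite its name, bucket_cap_mb holds a byte count: the class default is 32 * 1024 * 1024 bytes, and a bucket_cap_mb constructor argument is multiplied by 1024 * 1024 before use. A back-of-envelope sketch (made-up gradient sizes, assuming fp32) of how the bucket cap translates into allreduce calls:

# Illustrative arithmetic only; parameter sizes below are hypothetical.
bucket_cap_bytes = 32 * 1024 * 1024           # default bucket size, in bytes
elem_size = 4                                  # assume fp32 gradients
bucket_elems = bucket_cap_bytes // elem_size   # elements that fit in one bucket

# Hypothetical model: 100 gradient tensors of ~100 KiB each (~10 MiB total).
grad_numels = [26_214] * 100
total_elems = sum(grad_numels)

# Per-parameter path: one allreduce per tensor.
per_param_calls = len(grad_numels)             # 100

# Bucketized path: roughly one allreduce per filled bucket. The packing in
# bucketize_and_allreduce is greedy, so the real count can be slightly higher.
bucketized_calls = -(-total_elems // bucket_elems)   # ceil division -> 1

print(per_param_calls, bucketized_calls)       # 100 vs 1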

torchft/local_sgd_test.py (+95, -1)
@@ -9,7 +9,8 @@
 from unittest.mock import MagicMock, create_autospec
 
 import torch
-from torch import nn, optim
+from parameterized import parameterized
+from torch import Tensor, nn, optim
 from torch.distributed.tensor import DTensor
 
 from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
@@ -147,3 +148,96 @@ def test_diloco_healthy(self) -> None:
 
         outer_opt_state = outer_optimizer.state_dict()
         self.assertEqual(len(outer_opt_state["state"]), parameter_count)
+
+    @parameterized.expand(
+        [
+            ("bucketized_should_use_fewer_calls", True, True),
+            ("non_bucketized_should_call_per_param", False, False),
+        ]
+    )
+    def test_diloco_allreduce_call_efficiency(
+        self,
+        name: str,
+        use_bucketization: bool,
+        expect_fewer_calls: bool,
+    ) -> None:
+        model = SimpleModel()
+
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        with DiLoCo(
+            manager,
+            model,
+            inner_optimizer,
+            outer_optimizer,
+            sync_every=2,
+            use_bucketization=use_bucketization,
+        ) as diloco:
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+        allreduce_calls = manager.allreduce.call_count
+        param_count = len([p for p in model.parameters() if p.requires_grad])
+
+        if expect_fewer_calls:
+            self.assertLess(int(allreduce_calls), int(param_count))
+        else:
+            self.assertEqual(int(allreduce_calls), int(param_count))
+
+    def test_bucketization_correctness(self) -> None:
+        class TinyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
+                self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
+
+            def forward(self, x):
+                return x @ self.w1.unsqueeze(0).T + self.w2.sum()
+
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        # Manually assign fake gradients
+        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
+        for p, g in zip(model.parameters(), grads):
+            p.grad = g.clone()
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(tensor: Tensor) -> MagicMock:
+            tensor.mul_(2)
+            return MagicMock(wait=lambda: None)
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(
+            manager, model, inner_opt, outer_opt, sync_every=2, use_bucketization=True
+        )
+        diloco.bucket_cap_mb = 10 * 1024 * 1024
+
+        # Run only bucketized logic
+        diloco._average_grads()
+
+        # Expect grads to have been doubled
+        expected_grads = [g * 2 for g in grads]
+        for param, expected in zip(model.parameters(), expected_grads):
+            torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)
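
As a side note on why the doubling fake allreduce in test_bucketization_correctness is a meaningful check: if packing into the flat buffer, reducing, and scattering back preserves each tensor's slice, then every gradient must come back exactly doubled. A standalone sketch of that round trip (illustration only, not part of the commit):

# Pack two tensors into one flat buffer, "allreduce" by doubling, then unpack.
import torch

grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]

flat = torch.cat([g.view(-1) for g in grads])   # pack
flat.mul_(2)                                     # stand-in for the real allreduce

offset = 0
for g in grads:                                  # unpack each tensor's slice
    n = g.numel()
    g.copy_(flat[offset : offset + n].view_as(g))
    offset += n

assert torch.equal(grads[0], torch.tensor([2.0, 4.0]))
assert torch.equal(grads[1], torch.tensor([6.0, 8.0, 10.0]))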

0 commit comments