
Commit d774938

Adding a flag and unit tests
There is still an issue when running the unit test cases; will look into it more.
1 parent: 910eb3f

File tree: 2 files changed (+109, -20 lines)


torchft/local_sgd.py

Lines changed: 55 additions & 20 deletions
@@ -13,11 +13,11 @@
 from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
 
 import torch
+import torch.distributed as dist
 from torch import nn, optim
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
 from torch.utils.hooks import RemovableHandle
-import torch.distributed as dist
 
 from torchft.manager import Manager
 
@@ -183,6 +183,8 @@ class DiLoCo(LocalSGD):
     diloco: https://arxiv.org/pdf/2311.08105
     """
 
+    BUCKET_SIZE_BYTES = 32 * 1024 * 1024
+
     def __init__(
         self,
         manager: Manager,
@@ -192,6 +194,7 @@ def __init__(
         sync_every: int,
         backup_device: Optional[torch.device] = None,
         pin_memory: bool = True,
+        use_bucketization=False,
     ) -> None:
         if manager._use_async_quorum:
             raise ValueError(
@@ -224,35 +227,67 @@ def _perform_sync(self) -> None:
         self._outer_optimizer.step()
         self._save_parameters()
         self._outer_optimizer.zero_grad()
-
+
     def _average_grads(self) -> None:
         """
-        Efficiently averages gradients across the diloco group using buffer-based bucketization.
+        Efficiently averages gradients across the group using either:
+        - Per-parameter allreduce (old behavior)
+        - Bucketized allreduce (new behavior)
         """
+        if self.use_bucketization:
+            self._allreduce_bucketized()
+        else:
+            self._allreduce_per_param()
 
-        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
+    def _allreduce_per_param(self) -> None:
+        """Performs allreduce on each gradient tensor separately (original method)."""
+        works = []
+        for p in self._model.parameters():
+            if p.grad is None:
+                continue
+            work = self._manager.allreduce(p.grad)
+            works.append(work)
+
+        for work in works:
+            work.wait()
 
+    def _allreduce_bucketized(self) -> None:
+        """
+        Averages gradients using bucketized allreduce with a fixed 32MB buffer.
+        """
+
+        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
         if not grads:
-            return  # No gradients to process
+            return
 
-        # Compute total size and allocate a flat buffer for all gradients
+        # Compute total size and allocate a flat buffer
         total_size = sum(g.numel() for g in grads)
-        flat_buffer = torch.zeros(total_size, dtype=grads[0].dtype, device=grads[0].device)
+        dtype, device = grads[0].dtype, grads[0].device
 
-        # Pack gradients into the buffer
+        # Process in fixed 32MB chunks
        offset = 0
-        for g in grads:
-            flat_buffer[offset : offset + g.numel()].copy_(g.view(-1))
-            offset += g.numel()
+        while offset < total_size:
+            # Compute chunk size
+            chunk_size = min(
+                self.BUCKET_SIZE_BYTES // grads[0].element_size(), total_size - offset
+            )
 
-        # Perform Allreduce on the entire buffer
-        work = self._manager.allreduce(flat_buffer)
+            flat_buffer = torch.zeros(chunk_size, dtype=dtype, device=device)
 
-        # Wait for Allreduce to complete
-        work.wait()
+            # Pack gradients into buffer
+            pack_offset, bucket_tensors = 0, []
+            for g in grads:
+                numel = g.numel()
+                if pack_offset + numel > chunk_size:
+                    break
+                flat_buffer[pack_offset : pack_offset + numel].copy_(g.view(-1))
+                bucket_tensors.append((g, pack_offset, numel))
+                pack_offset += numel
 
-        # Unpack gradients back into their original tensors
-        offset = 0
-        for g in grads:
-            g.copy_(flat_buffer[offset : offset + g.numel()].view_as(g))
-            offset += g.numel()
+            work = self._manager.allreduce(flat_buffer)
+            work.wait()
+
+            for g, pack_offset, numel in bucket_tensors:
+                g.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(g))
+
+            offset += chunk_size  # Move to next chunk
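
Note on the bucketized path above: the inner packing loop restarts from the first gradient on every iteration of the outer while loop, so when the gradients do not all fit in a single 32 MB bucket the leading tensors get packed and reduced again while the later ones are never averaged. The following is a minimal, self-contained sketch (not the committed code) of one way to walk the gradient list exactly once. It assumes only that the allreduce callable returns a work handle with a wait() method, like Manager.allreduce in the diff; the function name and signature are illustrative.

    from typing import Callable, List

    import torch


    def allreduce_bucketized_sketch(
        grads: List[torch.Tensor],
        allreduce: Callable[[torch.Tensor], object],
        bucket_size_bytes: int = 32 * 1024 * 1024,
    ) -> None:
        """Average gradients in fixed-size buckets, visiting each gradient once."""
        if not grads:
            return

        dtype, device = grads[0].dtype, grads[0].device
        bucket_numel = max(1, bucket_size_bytes // grads[0].element_size())

        i = 0  # index of the next gradient to pack
        while i < len(grads):
            # Greedily pack consecutive gradients until the bucket would overflow.
            # A gradient larger than the nominal bucket is packed alone into an
            # oversized bucket, so the loop always makes progress.
            bucket_tensors = []
            pack_offset = 0
            while i < len(grads):
                numel = grads[i].numel()
                if bucket_tensors and pack_offset + numel > bucket_numel:
                    break  # bucket full; flush it and start the next one
                bucket_tensors.append((grads[i], pack_offset, numel))
                pack_offset += numel
                i += 1

            # Flatten the bucket, reduce it in one collective, then unpack in place.
            flat_buffer = torch.zeros(pack_offset, dtype=dtype, device=device)
            for g, offset, numel in bucket_tensors:
                flat_buffer[offset : offset + numel].copy_(g.view(-1))

            work = allreduce(flat_buffer)
            work.wait()

            for g, offset, numel in bucket_tensors:
                g.copy_(flat_buffer[offset : offset + numel].view_as(g))

Sizing flat_buffer to the packed length rather than the nominal bucket size also avoids averaging trailing zero padding.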

torchft/local_sgd_test.py

Lines changed: 54 additions & 0 deletions
@@ -144,3 +144,57 @@ def test_diloco_healthy(self) -> None:
 
         outer_opt_state = outer_optimizer.state_dict()
         self.assertEqual(len(outer_opt_state["state"]), parameter_count)
+
+    def test_diloco_without_bucketization(self):
+        model = SimpleModel()
+        inner_optimizer = optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+
+        with DiLoCo(
+            manager,
+            model,
+            inner_optimizer,
+            outer_optimizer,
+            sync_every=2,
+            use_bucketization=False,
+        ) as diloco:
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+            self.assertEqual(diloco._local_step, 1)
+            self.assertEqual(
+                manager.allreduce.call_count, len(list(model.parameters()))
+            )
+
+    def test_diloco_with_bucketization(self):
+        model = SimpleModel()
+        inner_optimizer = optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+
+        with DiLoCo(
+            manager,
+            model,
+            inner_optimizer,
+            outer_optimizer,
+            sync_every=2,
+            use_bucketization=True,
+        ) as diloco:
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+            self.assertEqual(diloco._local_step, 1)
+            self.assertGreaterEqual(manager.allreduce.call_count, 1)
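
Because these tests build the manager with create_autospec(Manager), manager.allreduce returns a mock whose wait() is itself a mock, so the assertions above can only check call counts, not that gradients were actually averaged in place. Below is a small, hedged sketch of how the mock could hand back an immediately completed work handle instead; the _FakeWork helper is hypothetical and not part of torchft or this commit.

    from unittest.mock import create_autospec

    from torchft.manager import Manager


    class _FakeWork:
        """Stand-in for a collective work handle that completes immediately."""

        def wait(self) -> None:
            pass  # no communication happens; the tensor is left unchanged


    manager = create_autospec(Manager)
    manager._use_async_quorum = False
    # Every allreduce call returns a handle whose wait() is a no-op, so the code
    # paths that call work.wait() run for real while staying single-process.
    manager.allreduce.side_effect = lambda tensor: _FakeWork()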
