
Implementing bucketized model averaging for LocalSGD #111


Open · wants to merge 9 commits into main
86 changes: 82 additions & 4 deletions torchft/local_sgd.py
@@ -13,6 +13,7 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

import torch
import torch.distributed as dist
from torch import nn, optim
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
@@ -153,6 +154,9 @@ class DiLoCo:
    diloco: https://arxiv.org/pdf/2311.08105
    """

    # NOTE: despite the "mb" name, this value is stored in bytes (default 32 MiB).
    bucket_cap_mb = 32 * 1024 * 1024
    use_bucketization = False

    def __init__(
        self,
        manager: Manager,
@@ -162,6 +166,8 @@ def __init__(
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
        bucket_cap_mb: Optional[int] = None,
    ) -> None:
        if manager._use_async_quorum:
            raise ValueError(
@@ -180,6 +186,12 @@ def __init__(

        self._hooks: List[RemovableHandle] = []
        self._outer_optimizer = outer_optimizer

        if bucket_cap_mb is not None:
            # The argument is given in MiB; store it in bytes (e.g. 32 -> 33_554_432).
            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

        self.use_bucketization = use_bucketization

        self.original_parameters: Dict[str, torch.Tensor] = {}
        for name, p in self._model.named_parameters():
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
@@ -266,14 +278,80 @@ def _perform_sync(self) -> None:

    def _average_grads(self) -> None:
        """
        Averages the gradients across the diloco group using either:
        - per-parameter allreduce (the original behavior), or
        - bucketized allreduce (the new behavior).
        """
        if self.use_bucketization:
            self._allreduce_bucketized()
        else:
            self._allreduce_per_param()
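
A rough back-of-envelope sketch of why bucketization cuts the number of allreduce calls; the fp32 dtype, the default 32 MiB cap, and the parameter count are assumptions for illustration:

bucket_cap_bytes = 32 * 1024 * 1024          # default cap in bytes
elements_per_bucket = bucket_cap_bytes // 4  # fp32 is 4 bytes: 8_388_608 elements
total_elements = 125_000_000                 # hypothetical model size
num_buckets = -(-total_elements // elements_per_bucket)  # ceil division: 15 calls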

    def _allreduce_per_param(self) -> None:
        """Performs allreduce on each gradient tensor separately (original method)."""
        works = []
        for p in self._model.parameters():
            # Perform allreduce on the pseudogradients, skipping any parameter
            # that has no gradient (the previous assert made this skip dead code).
            if p.grad is None:
                continue
            work = self._manager.allreduce(p.grad)
            works.append(work)

        # Wait for all allreduce operations to complete
        for work in works:
            work.wait()

    def bucketize_and_allreduce(
        self,
        tensors: List[torch.Tensor],
        allreduce_fn: Callable[[torch.Tensor], Any],
        bucket_size_bytes: int,
    ) -> None:
        """
        Applies allreduce on a list of tensors using bucketization.

        All tensors are assumed to share a single dtype and device.

        Args:
            tensors: List of torch tensors (e.g., gradients).
            allreduce_fn: Function that takes a tensor and performs allreduce (e.g., manager.allreduce).
            bucket_size_bytes: Max size of each bucket in bytes.
        """
        if not tensors:
            return

        total_size = sum(t.numel() for t in tensors)
        dtype, device = tensors[0].dtype, tensors[0].device

        offset = 0
        flat_index = 0
        while offset < total_size:
            chunk_size = min(
                bucket_size_bytes // tensors[0].element_size(), total_size - offset
            )
            flat_buffer = torch.zeros(chunk_size, dtype=dtype, device=device)

            # Pack as many whole tensors as fit into this bucket; tensors are
            # never split across buckets.
            pack_offset, bucket_tensors = 0, []
            for t in tensors[flat_index:]:
                numel = t.numel()
                if pack_offset + numel > chunk_size:
                    break
                flat_buffer[pack_offset : pack_offset + numel].copy_(t.view(-1))
                bucket_tensors.append((t, pack_offset, numel))
                pack_offset += numel
                flat_index += 1

            if not bucket_tensors:
                # A tensor larger than the bucket would otherwise be skipped
                # silently while the loop spins forward.
                raise ValueError(
                    f"Tensor with {tensors[flat_index].numel()} elements exceeds "
                    f"the bucket capacity of {chunk_size} elements"
                )

            work = allreduce_fn(flat_buffer)
            work.wait()

            # Unpack the reduced values back into the original tensors.
            for t, pack_offset, numel in bucket_tensors:
                t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))

            # Advance by the number of elements actually packed, not the full
            # chunk size, so a partially filled bucket does not skip tensors.
            offset += pack_offset
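
For intuition, a standalone sketch of the packing behavior; the 16-byte cap (four fp32 elements) is chosen for illustration, `doubling_allreduce` and `FakeWork` are hypothetical stand-ins for `manager.allreduce` and its returned work handle, and `diloco` is assumed to be an already-constructed DiLoCo instance:

class FakeWork:
    def wait(self) -> None:
        pass

def doubling_allreduce(buf: torch.Tensor) -> FakeWork:
    buf.mul_(2)  # stand-in for an averaging allreduce across the group
    return FakeWork()

# With a 4-element cap, the 2- and 1-element tensors share the first bucket
# and the 3-element tensor fills the second: two allreduce calls, not three.
tensors = [torch.ones(2), torch.ones(1), torch.ones(3)]
diloco.bucketize_and_allreduce(tensors, doubling_allreduce, bucket_size_bytes=16)
assert all(t.eq(2.0).all() for t in tensors)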

    def _allreduce_bucketized(self) -> None:
        """
        Averages gradients using bucketized allreduce with a fixed buffer.
        """
        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
        self.bucketize_and_allreduce(
            grads,
            allreduce_fn=self._manager.allreduce,
            bucket_size_bytes=self.bucket_cap_mb,
        )
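
With these changes, enabling bucketized averaging is a single constructor flag. A hedged usage sketch follows; the `manager`, `model`, optimizer, and `dataloader` objects are assumed to already exist as in the existing DiLoCo examples, and the names are placeholders:

from torchft.local_sgd import DiLoCo

with DiLoCo(
    manager,
    model,
    inner_optimizer,
    outer_optimizer,
    sync_every=100,
    use_bucketization=True,  # pack pseudogradients into flat buckets
    bucket_cap_mb=32,        # given in MiB, converted to bytes internally
) as diloco:
    for batch in dataloader:
        loss = model(batch).mean()
        loss.backward()
        inner_optimizer.step()
        inner_optimizer.zero_grad()
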
93 changes: 92 additions & 1 deletion torchft/local_sgd_test.py
@@ -6,9 +6,10 @@

from typing import Dict
from unittest import TestCase
from unittest.mock import create_autospec
from unittest.mock import MagicMock, create_autospec

import torch
from parameterized import parameterized
from torch import nn, optim

from torchft.local_sgd import DiLoCo, LocalSGD
@@ -129,3 +130,93 @@ def test_diloco_healthy(self) -> None:

        outer_opt_state = outer_optimizer.state_dict()
        self.assertEqual(len(outer_opt_state["state"]), parameter_count)

    @parameterized.expand(
        [
            ("bucketized_should_use_fewer_calls", True, True),
            ("non_bucketized_should_call_per_param", False, False),
        ]
    )
    def test_diloco_allreduce_call_efficiency(
        self, name: str, use_bucketization: bool, expect_fewer_calls: bool
    ) -> None:
        model = SimpleModel()

        inner_optimizer = torch.optim.AdamW(
            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
        )
        outer_optimizer = torch.optim.SGD(
            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
        )

        manager = create_autospec(Manager)
        manager._use_async_quorum = False
        manager.should_commit.return_value = True

        with DiLoCo(
            manager,
            model,
            inner_optimizer,
            outer_optimizer,
            sync_every=2,
            use_bucketization=use_bucketization,
        ) as diloco:
            inp = torch.rand(2, 3)
            loss = model(inp).mean()
            loss.backward()
            inner_optimizer.step()

            loss = model(inp).mean()
            loss.backward()
            inner_optimizer.step()

        allreduce_calls = manager.allreduce.call_count
        param_count = len([p for p in model.parameters() if p.requires_grad])

        if expect_fewer_calls:
            self.assertLess(allreduce_calls, param_count)
        else:
            self.assertEqual(allreduce_calls, param_count)

    def test_bucketization_correctness(self) -> None:
        class TinyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
                self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))

            def forward(self, x):
                return x @ self.w1.unsqueeze(0).T + self.w2.sum()

        model = TinyModel()
        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)

        # Manually assign fake gradients
        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
        for p, g in zip(model.parameters(), grads):
            p.grad = g.clone()

        manager = create_autospec(Manager)
        manager._use_async_quorum = False
        manager.should_commit.return_value = True

        # Define a fake allreduce that doubles the buffer in place
        def fake_allreduce(tensor):
            tensor.mul_(2)
            return MagicMock(wait=lambda: None)

        manager.allreduce.side_effect = fake_allreduce

        diloco = DiLoCo(
            manager, model, inner_opt, outer_opt, sync_every=2, use_bucketization=True
        )
        diloco.bucket_cap_mb = 10 * 1024 * 1024

        # Run only the bucketized averaging logic
        diloco._average_grads()

        # Expect the gradients to have been doubled
        expected_grads = [g * 2 for g in grads]
        for param, expected in zip(model.parameters(), expected_grads):
            torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)