diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 0cdf7cef..5748def1 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -13,6 +13,7 @@
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type
 
 import torch
+import torch.distributed as dist
 from torch import nn, optim
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
@@ -166,6 +167,9 @@ class DiLoCo:
     DiLoCo paper: https://arxiv.org/pdf/2311.08105
     """
 
+    bucket_cap_mb: int = 32 * 1024 * 1024
+    use_bucketization: bool = False
+
     def __init__(
         self,
         manager: Manager,
@@ -175,6 +179,8 @@ def __init__(
         sync_every: int,
         backup_device: Optional[torch.device] = None,
         pin_memory: bool = True,
+        use_bucketization: bool = False,
+        bucket_cap_mb: Optional[int] = None,
     ) -> None:
         """
         Args:
@@ -204,6 +210,12 @@ def __init__(
 
         self._hooks: List[RemovableHandle] = []
         self._outer_optimizer = outer_optimizer
+
+        if bucket_cap_mb is not None:
+            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
+
+        self.use_bucketization = use_bucketization
+
         self.original_parameters: Dict[str, torch.Tensor] = {}
         for name, p in self._model.named_parameters():
             if isinstance(p, DTensor):
@@ -308,8 +320,17 @@ def _perform_sync(self) -> None:
 
     def _average_grads(self) -> None:
         """
-        Average the gradients across the diloco group.
+        Efficiently averages gradients across the group using either:
+        - Per-parameter allreduce (old behavior)
+        - Bucketized allreduce (new behavior)
         """
+        if self.use_bucketization:
+            self._allreduce_bucketized()
+        else:
+            self._allreduce_per_param()
+
+    def _allreduce_per_param(self) -> None:
+        """Performs allreduce on each gradient tensor separately (original method)."""
         works = []
         for p in self._model.parameters():
             # Perform allreduce on the pseudogradients
@@ -319,6 +340,60 @@ def _average_grads(self) -> None:
             else:
                 work = self._manager.allreduce(p.grad)
             works.append(work)
-        # Wait for all allreduce operations to complete
+
         for work in works:
             work.wait()
+
+    def bucketize_and_allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        bucket_size_bytes: int,
+    ) -> None:
+        """
+        Applies allreduce on a list of tensors using bucketization.
+
+        Args:
+            tensors: List of torch tensors (e.g., gradients).
+            bucket_size_bytes: Max size of each bucket in bytes.
+        """
+        if not tensors:
+            return
+
+        total_size = sum(t.numel() for t in tensors)
+        dtype, device = tensors[0].dtype, tensors[0].device
+
+        offset = 0
+        flat_index = 0
+        while offset < total_size:
+            chunk_size = min(
+                bucket_size_bytes // tensors[0].element_size(), total_size - offset
+            )
+            flat_buffer = torch.zeros(chunk_size, dtype=dtype, device=device)
+
+            pack_offset, bucket_tensors = 0, []
+            for t in tensors[flat_index:]:
+                numel = t.numel()
+                if pack_offset + numel > chunk_size:
+                    break
+                flat_buffer[pack_offset : pack_offset + numel].copy_(t.view(-1))
+                bucket_tensors.append((t, pack_offset, numel))
+                pack_offset += numel
+                flat_index += 1
+
+            work = self._manager.allreduce(flat_buffer)
+            work.wait()
+
+            for t, pack_offset, numel in bucket_tensors:
+                t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
+
+            offset += chunk_size
+
+    def _allreduce_bucketized(self) -> None:
+        """
+        Averages gradients using bucketized allreduce with a fixed buffer.
+        """
+        grads = [p.grad for p in self._model.parameters() if p.grad is not None]
+        self.bucketize_and_allreduce(
+            grads,
+            bucket_size_bytes=self.bucket_cap_mb,
+        )
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index 5c3e67b9..fc7a6857 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -9,7 +9,8 @@
 from unittest.mock import MagicMock, create_autospec
 
 import torch
-from torch import nn, optim
+from parameterized import parameterized
+from torch import Tensor, nn, optim
 from torch.distributed.tensor import DTensor
 
 from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
@@ -147,3 +148,96 @@ def test_diloco_healthy(self) -> None:
 
         outer_opt_state = outer_optimizer.state_dict()
         self.assertEqual(len(outer_opt_state["state"]), parameter_count)
+
+    @parameterized.expand(
+        [
+            ("bucketized_should_use_fewer_calls", True, True),
+            ("non_bucketized_should_call_per_param", False, False),
+        ]
+    )
+    def test_diloco_allreduce_call_efficiency(
+        self,
+        name: str,
+        use_bucketization: bool,
+        expect_fewer_calls: bool,
+    ) -> None:
+        model = SimpleModel()
+
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        with DiLoCo(
+            manager,
+            model,
+            inner_optimizer,
+            outer_optimizer,
+            sync_every=2,
+            use_bucketization=use_bucketization,
+        ) as diloco:
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+        allreduce_calls = manager.allreduce.call_count
+        param_count = len([p for p in model.parameters() if p.requires_grad])
+
+        if expect_fewer_calls:
+            self.assertLess(int(allreduce_calls), int(param_count))
+        else:
+            self.assertEqual(int(allreduce_calls), int(param_count))
+
+    def test_bucketization_correctness(self) -> None:
+        class TinyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
+                self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
+
+            def forward(self, x):
+                return x @ self.w1.unsqueeze(0).T + self.w2.sum()
+
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        # Manually assign fake gradients
+        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
+        for p, g in zip(model.parameters(), grads):
+            p.grad = g.clone()
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(tensor: Tensor) -> MagicMock:
+            tensor.mul_(2)
+            return MagicMock(wait=lambda: None)
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(
+            manager, model, inner_opt, outer_opt, sync_every=2, use_bucketization=True
+        )
+        diloco.bucket_cap_mb = 10 * 1024 * 1024
+
+        # Run only bucketized logic
+        diloco._average_grads()
+
+        # Expect grads to have been doubled
+        expected_grads = [g * 2 for g in grads]
+        for param, expected in zip(model.parameters(), expected_grads):
+            torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)