Skip to content

Commit 6d25cc8

Browse files
committed
Updates based on comments
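1) Fixed lint issues (e.g. added the missing type annotation on `use_bucketization`); 2) Renamed the bucket-size variable from `BUCKET_SIZE_BYTES` to `bucket_cap_mb` and exposed it as an optional constructor argument; 3) Merged the separate with/without-bucketization tests into a single parameterised unit test.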
1 parent 2de4bdf commit 6d25cc8

File tree

2 files changed

+31
-32
lines changed

2 files changed

+31
-32
lines changed

torchft/local_sgd.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ class DiLoCo(LocalSGD):
183183
diloco: https://arxiv.org/pdf/2311.08105
184184
"""
185185

186-
BUCKET_SIZE_BYTES = 32 * 1024 * 1024
186+
bucket_cap_mb = 32 * 1024 * 1024
187+
use_bucketization = False
187188

188189
def __init__(
189190
self,
@@ -194,7 +195,8 @@ def __init__(
194195
sync_every: int,
195196
backup_device: Optional[torch.device] = None,
196197
pin_memory: bool = True,
197-
use_bucketization=False,
198+
use_bucketization: bool = False,
199+
bucket_cap_mb: int = None,
198200
) -> None:
199201
if manager._use_async_quorum:
200202
raise ValueError(
@@ -205,6 +207,10 @@ def __init__(
205207
manager, model, inner_optimizer, sync_every, backup_device, pin_memory
206208
)
207209
self._outer_optimizer = outer_optimizer
210+
if bucket_cap_mb is not None:
211+
self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
212+
213+
self.use_bucketization = use_bucketization
208214

209215
def _perform_sync(self) -> None:
210216
"""

torchft/local_sgd_test.py

+23-30
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from unittest.mock import create_autospec
1010

1111
import torch
12+
from parameterized import parameterized
1213
from torch import nn, optim
1314

1415
from torchft.local_sgd import DiLoCo, LocalSGD
@@ -145,42 +146,33 @@ def test_diloco_healthy(self) -> None:
145146
outer_opt_state = outer_optimizer.state_dict()
146147
self.assertEqual(len(outer_opt_state["state"]), parameter_count)
147148

148-
def test_diloco_without_bucketization(self):
149+
@parameterized.expand(
150+
[
151+
(
152+
"without_bucketization",
153+
False,
154+
lambda self, manager, model: self.assertEqual(
155+
manager.allreduce.call_count, len(list(model.parameters()))
156+
),
157+
),
158+
(
159+
"with_bucketization",
160+
True,
161+
lambda self, manager, model: self.assertGreaterEqual(
162+
manager.allreduce.call_count, 1
163+
),
164+
),
165+
]
166+
)
167+
def test_diloco_all_reduce(self, name, use_bucketization, assert_func):
149168
model = SimpleModel()
150169
inner_optimizer = optim.AdamW(
151170
model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
152171
)
153172
outer_optimizer = optim.SGD(
154173
model.parameters(), lr=0.7, momentum=0.9, nesterov=True
155174
)
156-
manager = create_autospec(Manager)
157-
manager._use_async_quorum = False
158-
159-
with DiLoCo(
160-
manager,
161-
model,
162-
inner_optimizer,
163-
outer_optimizer,
164-
sync_every=2,
165-
use_bucketization=False,
166-
) as diloco:
167-
inp = torch.rand(2, 3)
168-
loss = model(inp).mean()
169-
loss.backward()
170-
inner_optimizer.step()
171-
self.assertEqual(diloco._local_step, 1)
172-
self.assertEqual(
173-
manager.allreduce.call_count, len(list(model.parameters()))
174-
)
175175

176-
def test_diloco_with_bucketization(self):
177-
model = SimpleModel()
178-
inner_optimizer = optim.AdamW(
179-
model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
180-
)
181-
outer_optimizer = optim.SGD(
182-
model.parameters(), lr=0.7, momentum=0.9, nesterov=True
183-
)
184176
manager = create_autospec(Manager)
185177
manager._use_async_quorum = False
186178

@@ -190,11 +182,12 @@ def test_diloco_with_bucketization(self):
190182
inner_optimizer,
191183
outer_optimizer,
192184
sync_every=2,
193-
use_bucketization=True,
185+
use_bucketization=use_bucketization,
194186
) as diloco:
195187
inp = torch.rand(2, 3)
196188
loss = model(inp).mean()
197189
loss.backward()
198190
inner_optimizer.step()
191+
199192
self.assertEqual(diloco._local_step, 1)
200-
self.assertGreaterEqual(manager.allreduce.call_count, 1)
193+
assert_func(self, manager, model)

0 commit comments

Comments
 (0)