
Commit 082753c

Clean up local_sgd and diloco (#120)

1 parent 6fe4c8e commit 082753c

File tree

4 files changed: +210 −126 lines changed

torchft/local_sgd.py: +94 −54
@@ -10,7 +10,7 @@
 """
 import logging
 from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

 import torch
 from torch import nn, optim
@@ -59,8 +59,6 @@ def __init__(
         model: nn.Module,
         optimizer: optim.Optimizer,
         sync_every: int,
-        backup_device: Optional[torch.device] = None,
-        pin_memory: bool = True,
     ) -> None:
         """
         Args:
@@ -78,21 +76,8 @@ def __init__(
         self._local_step = 0
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-        device = backup_device or torch.device("cpu")
-        self._backup_parameters: Dict[str, torch.Tensor] = {}
-        for name, p in self._model.named_parameters():
-            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
-            if (
-                pin_memory
-                and t.device == torch.device("cpu")
-                and torch.cuda.is_available()
-            ):
-                t = t.pin_memory()
-            self._backup_parameters[name] = t

         self._hooks: List[RemovableHandle] = []
-        # Need to copy the parameters to the host to be safe if we are on the first step.
-        self._save_parameters()

     def __enter__(self) -> "LocalSGD":
         # Add optimizer hook which increments the local step counter and syncs if necessary
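
With the backup-parameter bookkeeping removed, LocalSGD reduces to a step-counting context manager around the inner optimizer. A minimal usage sketch of that flow (not part of the diff; the manager, model, optimizer, and dataloader are assumed to be constructed elsewhere):

import torch
from torchft.local_sgd import LocalSGD

# Hypothetical training loop; manager, model, optimizer, and dataloader
# are assumptions, not defined in this commit.
with LocalSGD(manager, model, optimizer, sync_every=100):
    for batch, target in dataloader:
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(batch), target)
        loss.backward()
        optimizer.step()  # the post-step hook counts steps and calls sync() every 100 steps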
@@ -108,30 +93,15 @@ def __exit__(
         traceback: Optional[TracebackType],
     ) -> bool:
         # Handle any cleanup or error handling here
-        if exc_type is not None:
-            # If an exception occurred, restore parameters
-            self._restore_parameters()
         # Clean up hooks
         for hook in self._hooks:
             hook.remove()
         self._hooks.clear()

         return False # Propagate exceptions

-    def _save_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                self._backup_parameters[name].copy_(p.data, non_blocking=True)
-
-    def _restore_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                p.data.copy_(self._backup_parameters[name], non_blocking=False)
-
     def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
     ) -> None:
         """
         This hook is registered on the optimizer and is called after the optimizer step.
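
The hook signature now matches what torch.optim passes to post-step hooks: the positional arguments arrive as a tuple and the keyword arguments as a dict. A self-contained sketch of the same counting pattern using plain PyTorch rather than torchft (illustrative names only):

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)
step_count = 0

def post_step(_optim, _args, _kwargs) -> None:
    # Called after every opt.step(); mirrors what LocalSGD._step_post_hook does.
    global step_count
    step_count += 1
    if step_count % 5 == 0:
        print(f"would sync at step {step_count}")

handle = opt.register_step_post_hook(post_step)
for _ in range(10):
    opt.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    opt.step()
handle.remove()  # the same cleanup LocalSGD.__exit__ performs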
@@ -151,30 +121,31 @@ def sync(self) -> None:
     def _perform_sync(self) -> None:
         """
         Performs the synchronization of the model weights across the manager.
-        This method is intended to be overridden by subclasses to implement custom
-        synchronization logic.
         """
-        self._average()
+        averaged_parameters = self._average()
         if self._manager.should_commit():
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
-
-    def _average(self) -> None:
-        # TODO: do we need to broadcast buffers like DDP does?
+            # Update the model parameters with the averaged values
+            for param, avg_param in zip(self._model.parameters(), averaged_parameters):
+                param.data.copy_(avg_param)

+    def _average(self) -> list[torch.Tensor]:
+        """
+        Averages the model parameters across the manager and returns the averaged parameters.
+        """
         works = []
-
+        averaged_parameters = []
         for p in self._model.parameters():
-            # TODO: bucketize parameters
-            works.append(self._manager.allreduce(p.data.detach()))
-
+            # Create a new tensor to store the averaged parameter
+            p.data.grad = None
+            avg_param = p.data.clone()
+            works.append(self._manager.allreduce(avg_param))
+            averaged_parameters.append(avg_param)
         for work in works:
             work.wait()
+        return averaged_parameters


-class DiLoCo(LocalSGD):
+class DiLoCo:
     """
     DiLoCo is a subclass of LocalSGD that overrides the synchronization
     mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).
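
The rewritten _average clones each parameter, issues an asynchronous allreduce on the clone, waits on all of the handles, and returns the averaged copies; _perform_sync only copies them back into the model when should_commit() succeeds, so a failed commit leaves the local weights untouched. A rough stand-in for that clone/allreduce/wait pattern using plain torch.distributed (an assumption; torchft's manager.allreduce averages internally and returns a waitable handle):

import torch
import torch.distributed as dist

def average_parameters(model: torch.nn.Module) -> list[torch.Tensor]:
    # Clone each parameter so the live weights stay untouched until commit.
    averaged, works = [], []
    for p in model.parameters():
        avg = p.data.clone()
        works.append(dist.all_reduce(avg, op=dist.ReduceOp.SUM, async_op=True))
        averaged.append(avg)
    for work in works:
        work.wait()
    for avg in averaged:
        avg.div_(dist.get_world_size())  # averaging done by hand here
    return averaged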
@@ -197,27 +168,96 @@ def __init__(
                 "Using DiLoCo require synchronous quorum to be enabled. "
                 "Ensure that the manager is initialized with use_async_quorum=False"
             )
-        super().__init__(
-            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
-        )
+        super().__init__()
+        self._manager = manager
+        self._model = model
+        self._local_optimizer = inner_optimizer
+        self._local_step = 0
+        self._sync_every = sync_every
+        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+        self._backup_device = backup_device
+        self._pin_memory = pin_memory
+
+        self._hooks: List[RemovableHandle] = []
         self._outer_optimizer = outer_optimizer
+        self.original_parameters: Dict[str, torch.Tensor] = {}
+        for name, p in self._model.named_parameters():
+            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
+            if (
+                self._pin_memory
+                and t.device == torch.device("cpu")
+                and torch.cuda.is_available()
+            ):
+                t = t.pin_memory()
+            self.original_parameters[name] = t
+
+        # Need to copy the parameters to the host to be safe if we are on the first step.
+        self._save_parameters()
+
+    def _save_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self.original_parameters[name].copy_(p.data, non_blocking=True)
+
+    def _restore_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self.original_parameters[name], non_blocking=False)
+
+    def __enter__(self) -> "DiLoCo":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False # Propagate exceptions
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
+
+    def sync(self) -> None:
+        """
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0

     def _perform_sync(self) -> None:
         """
         Overrides the sync method to calculate the pseugradient, average them across the manager group, and
         step using the outer optimizer.
         """
-
         # Set the .grad field of each parameter to its pseudogradient
         for name, p in self._model.named_parameters():
-            assert name in self._backup_parameters
-            pseudogradient = p.data - self._backup_parameters[name]
+            pseudogradient = p.data - self.original_parameters[name]
             p.grad = pseudogradient

         self._average_grads()
         # Restore the parameters back to the previous state
         self._restore_parameters()
-
         if self._manager.should_commit():
             # Use the outer optimizer to update the model parameters
             self._outer_optimizer.step()
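
Taken together, DiLoCo now owns its backup parameters, optimizer hook, and context-manager protocol instead of inheriting them from LocalSGD. A usage sketch with the usual DiLoCo split of a fast inner optimizer and a slow outer optimizer (the manager construction, hyperparameters, and argument order are assumptions, not taken from this commit):

import torch
from torch import nn, optim
from torchft.local_sgd import DiLoCo

# Assumed to exist: a torchft Manager created with use_async_quorum=False,
# plus a model and a dataloader.
inner_opt = optim.AdamW(model.parameters(), lr=4e-4)
outer_opt = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

with DiLoCo(manager, model, inner_opt, outer_opt, sync_every=500):
    for batch, target in dataloader:
        inner_opt.zero_grad()
        loss = nn.functional.cross_entropy(model(batch), target)
        loss.backward()
        inner_opt.step()  # every 500 steps the hook averages pseudogradients and runs the outer step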
