"""
import logging
from types import TracebackType
-from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

import torch
from torch import nn, optim
@@ -59,8 +59,6 @@ def __init__(
        model: nn.Module,
        optimizer: optim.Optimizer,
        sync_every: int,
-        backup_device: Optional[torch.device] = None,
-        pin_memory: bool = True,
    ) -> None:
        """
        Args:
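With `backup_device` and `pin_memory` removed from the signature, constructing `LocalSGD` now only takes the manager, model, optimizer, and sync interval. A hedged usage sketch of the context-manager API (the `manager` argument, the module path, and the training data below are illustrative placeholders, not part of this diff):

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD  # assumed module path for the file in this diff


def train_with_local_sgd(manager, num_steps: int = 16) -> None:
    # `manager` is assumed to be an already-initialized torchft Manager.
    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    with LocalSGD(manager, model, optimizer, sync_every=4):
        for _ in range(num_steps):
            optimizer.zero_grad()
            loss = model(torch.randn(8, 4)).sum()
            loss.backward()
            optimizer.step()  # every 4th step triggers an average + commit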
@@ -78,21 +76,8 @@ def __init__(
        self._local_step = 0
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-        device = backup_device or torch.device("cpu")
-        self._backup_parameters: Dict[str, torch.Tensor] = {}
-        for name, p in self._model.named_parameters():
-            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
-            if (
-                pin_memory
-                and t.device == torch.device("cpu")
-                and torch.cuda.is_available()
-            ):
-                t = t.pin_memory()
-            self._backup_parameters[name] = t

        self._hooks: List[RemovableHandle] = []
-        # Need to copy the parameters to the host to be safe if we are on the first step.
-        self._save_parameters()

    def __enter__(self) -> "LocalSGD":
        # Add optimizer hook which increments the local step counter and syncs if necessary
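For reference, the allocation being deleted here (and reintroduced on DiLoCo further down) keeps a pinned CPU copy of every parameter so saves can use non-blocking device-to-host copies. A standalone sketch of that pattern, separate from this diff:

import torch
from torch import nn

model = nn.Linear(4, 2)

# Host-side backup tensors, pinned when CUDA is available so that
# copy_(..., non_blocking=True) can overlap with compute.
backup = {}
for name, p in model.named_parameters():
    t = torch.empty_like(p, device="cpu")
    if torch.cuda.is_available():
        t = t.pin_memory()
    backup[name] = t

with torch.no_grad():
    for name, p in model.named_parameters():
        backup[name].copy_(p.data, non_blocking=True)  # save
    for name, p in model.named_parameters():
        p.data.copy_(backup[name])  # restore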
@@ -108,30 +93,15 @@ def __exit__(
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
-        if exc_type is not None:
-            # If an exception occurred, restore parameters
-            self._restore_parameters()
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

-    def _save_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                self._backup_parameters[name].copy_(p.data, non_blocking=True)
-
-    def _restore_parameters(self) -> None:
-        with torch.no_grad():
-            # TODO: consider running copy on a separate stream
-            for name, p in self._model.named_parameters():
-                p.data.copy_(self._backup_parameters[name], non_blocking=False)
-
    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
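The new hook signature matches what `Optimizer.register_step_post_hook` actually passes: the positional args as a tuple and the kwargs as a dict. A minimal, self-contained sketch of the step-counting mechanism (the print stands in for the real sync logic):

from typing import Any, Dict, Tuple

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)

sync_every = 3
local_step = 0


def step_post_hook(
    _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
    # Count optimizer steps and "sync" once the interval is reached.
    global local_step
    local_step += 1
    if local_step >= sync_every:
        print(f"would sync after {local_step} local steps")
        local_step = 0


handle = opt.register_step_post_hook(step_post_hook)

for _ in range(6):
    opt.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # the hook fires after each step; "sync" happens on steps 3 and 6

handle.remove()  # mirrors the cleanup done in __exit__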
@@ -151,30 +121,31 @@ def sync(self) -> None:
    def _perform_sync(self) -> None:
        """
        Performs the synchronization of the model weights across the manager.
-        This method is intended to be overridden by subclasses to implement custom
-        synchronization logic.
        """
-        self._average()
+        averaged_parameters = self._average()
        if self._manager.should_commit():
-            self._save_parameters()
-        else:
-            # commit failed, restore from the backup parameters
-            self._restore_parameters()
-
-    def _average(self) -> None:
-        # TODO: do we need to broadcast buffers like DDP does?
+            # Update the model parameters with the averaged values
+            for param, avg_param in zip(self._model.parameters(), averaged_parameters):
+                param.data.copy_(avg_param)

+    def _average(self) -> list[torch.Tensor]:
+        """
+        Averages the model parameters across the manager and returns the averaged parameters.
+        """
        works = []
-
+        averaged_parameters = []
        for p in self._model.parameters():
-            # TODO: bucketize parameters
-            works.append(self._manager.allreduce(p.data.detach()))
-
+            # Create a new tensor to store the averaged parameter
+            p.data.grad = None
+            avg_param = p.data.clone()
+            works.append(self._manager.allreduce(avg_param))
+            averaged_parameters.append(avg_param)
        for work in works:
            work.wait()
+        return averaged_parameters


-class DiLoCo(LocalSGD):
+class DiLoCo:
    """
    DiLoCo is a subclass of LocalSGD that overrides the synchronization
    mechanism to average and synchronize the pseudogradients (delta of the previous global weight and current local weights).
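The reworked LocalSGD flow above clones each parameter, allreduces the clones, and only copies the averages back if the commit succeeds. A manager-free sketch of that clone-average-commit flow, with two in-process replicas standing in for the replica group (all names here are illustrative, not torchft APIs):

from typing import List

import torch
from torch import nn

replicas = [nn.Linear(4, 2) for _ in range(2)]

# Group corresponding parameters and average them elementwise
# (a stand-in for manager.allreduce on cloned tensors).
param_groups = list(zip(*(r.parameters() for r in replicas)))
averaged: List[torch.Tensor] = [
    torch.stack([p.data for p in group]).mean(dim=0) for group in param_groups
]

should_commit = True  # stand-in for manager.should_commit()
if should_commit:
    with torch.no_grad():
        for group, avg in zip(param_groups, averaged):
            for p in group:
                p.data.copy_(avg)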
@@ -197,27 +168,96 @@ def __init__(
                "Using DiLoCo require synchronous quorum to be enabled. "
                "Ensure that the manager is initialized with use_async_quorum=False"
            )
-        super().__init__(
-            manager, model, inner_optimizer, sync_every, backup_device, pin_memory
-        )
+        super().__init__()
+        self._manager = manager
+        self._model = model
+        self._local_optimizer = inner_optimizer
+        self._local_step = 0
+        self._sync_every = sync_every
+        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+        self._backup_device = backup_device
+        self._pin_memory = pin_memory
+
+        self._hooks: List[RemovableHandle] = []
        self._outer_optimizer = outer_optimizer
+        self.original_parameters: Dict[str, torch.Tensor] = {}
+        for name, p in self._model.named_parameters():
+            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=self._backup_device)
+            if (
+                self._pin_memory
+                and t.device == torch.device("cpu")
+                and torch.cuda.is_available()
+            ):
+                t = t.pin_memory()
+            self.original_parameters[name] = t
+
+        # Need to copy the parameters to the host to be safe if we are on the first step.
+        self._save_parameters()
+
+    def _save_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self.original_parameters[name].copy_(p.data, non_blocking=True)
+
+    def _restore_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self.original_parameters[name], non_blocking=False)
+
+    def __enter__(self) -> "DiLoCo":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
+
+    def sync(self) -> None:
+        """
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0

    def _perform_sync(self) -> None:
        """
        Overrides the sync method to calculate the pseugradient, average them across the manager group, and
        step using the outer optimizer.
        """
-
        # Set the .grad field of each parameter to its pseudogradient
        for name, p in self._model.named_parameters():
-            assert name in self._backup_parameters
-            pseudogradient = p.data - self._backup_parameters[name]
+            pseudogradient = p.data - self.original_parameters[name]
            p.grad = pseudogradient

        self._average_grads()
        # Restore the parameters back to the previous state
        self._restore_parameters()
-
        if self._manager.should_commit():
            # Use the outer optimizer to update the model parameters
            self._outer_optimizer.step()
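To make the pseudogradient step concrete: the outer optimizer treats the drift accumulated over the inner steps as a gradient applied to the restored snapshot. A single-process sketch of that mechanic, without the manager's allreduce or commit (hyperparameters are illustrative only):

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
inner_opt = optim.AdamW(model.parameters(), lr=1e-3)
outer_opt = optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

# Snapshot of the "global" weights, analogous to self.original_parameters.
original = {name: p.detach().clone() for name, p in model.named_parameters()}

for _ in range(5):  # inner steps, i.e. sync_every
    inner_opt.zero_grad()
    model(torch.randn(8, 4)).sum().backward()
    inner_opt.step()

# Pseudogradient: delta between the locally updated weights and the snapshot,
# mirroring _perform_sync above (the allreduce of the grads is omitted here).
for name, p in model.named_parameters():
    p.grad = p.data - original[name]

# Restore the snapshot, then let the outer optimizer apply the pseudogradient.
with torch.no_grad():
    for name, p in model.named_parameters():
        p.data.copy_(original[name])
outer_opt.step()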