@@ -10,14 +10,16 @@

from graphium.nn.base_layers import FCLayer, MLP

+
class EnsembleLinear(nn.Module):
-    def __init__(self,
-                 in_dim: int,
-                 out_dim: int,
-                 num_ensemble: int,
-                 bias: bool = True,
-                 init_fn: Optional[Callable] = None,
-                 ):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        bias: bool = True,
+        init_fn: Optional[Callable] = None,
+    ):
        r"""
        Multiple linear layers that are applied in parallel with batched matrix multiplication with `torch.matmul`.

@@ -38,7 +40,7 @@ def __init__(self,
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_ensemble, 1, out_dim))
        else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

        # Initialize parameters
        self.init_fn = init_fn if init_fn is not None else mupi.xavier_uniform_
@@ -79,6 +81,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor:

        return h

+
class EnsembleFCLayer(FCLayer):
    def __init__(
        self,
@@ -162,7 +165,9 @@ def __init__(

        # Linear layer, or MuReadout layer
        if not is_readout_layer:
-            self.linear = EnsembleLinear(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn)
+            self.linear = EnsembleLinear(
+                in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn
+            )
        else:
            self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, bias=bias)

@@ -180,20 +185,23 @@ def __repr__(self):
        rep = rep[:-1] + f", num_ensemble={self.linear.weight.shape[0]})"
        return rep

+
class EnsembleMuReadoutGraphium(EnsembleLinear):
    """
    This layer implements an ensemble version of μP with a 1/width multiplier and a
    constant variance initialization for both weights and biases.
    """
-    def __init__(self,
-                 in_dim: int,
-                 out_dim: int,
-                 num_ensemble: int,
-                 bias: bool = True,
-                 init_fn: Optional[Callable] = None,
-                 readout_zero_init=False,
-                 output_mult=1.0
-                 ):
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        bias: bool = True,
+        init_fn: Optional[Callable] = None,
+        readout_zero_init=False,
+        output_mult=1.0,
+    ):
        self.output_mult = output_mult
        self.readout_zero_init = readout_zero_init
        self.base_width = in_dim
@@ -214,35 +222,35 @@ def reset_parameters(self) -> None:
        super().reset_parameters()

    def width_mult(self):
-        assert hasattr(self.weight, 'infshape'), (
-            'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
-            'switch to distributed training with '
-            'torch.nn.parallel.DistributedDataParallel instead'
+        assert hasattr(self.weight, "infshape"), (
+            "Please call set_base_shapes(...). If using torch.nn.DataParallel, "
+            "switch to distributed training with "
+            "torch.nn.parallel.DistributedDataParallel instead"
        )
        return self.weight.infshape.width_mult()

    def _rescale_parameters(self):
-        '''Rescale parameters to convert SP initialization to μP initialization.
+        """Rescale parameters to convert SP initialization to μP initialization.

        Warning: This method is NOT idempotent and should be called only once
        unless you know what you are doing.
-        '''
-        if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params:
+        """
+        if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params:
            raise RuntimeError(
                "`_rescale_parameters` has been called once before already. "
                "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n"
                "If you called `set_base_shapes` on a model loaded from a checkpoint, "
                "or just want to re-set the base shapes of an existing model, "
                "make sure to set the flag `rescale_params=False`.\n"
-                "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.")
+                "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call."
+            )
        if self.bias is not None:
-            self.bias.data *= self.width_mult()**0.5
-        self.weight.data *= self.width_mult()**0.5
+            self.bias.data *= self.width_mult() ** 0.5
+        self.weight.data *= self.width_mult() ** 0.5
        self._has_rescaled_params = True

    def forward(self, x):
-        return super().forward(
-            self.output_mult * x / self.width_mult())
+        return super().forward(self.output_mult * x / self.width_mult())

    @property
    def absolute_width(self):
@@ -361,7 +369,6 @@ def __init__(
            last_layer_is_readout=last_layer_is_readout,
            droppath_rate=droppath_rate,
            constant_droppath_rate=constant_droppath_rate,
-
        )

        self.reduction = self._parse_reduction(reduction)
@@ -415,4 +422,3 @@ def __repr__(self):
        """
        rep = super().__repr__()
        rep = rep[:-1] + f", num_ensemble={self.layers[0].linear.weight.shape[0]})"
-
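As an aside, here is a minimal sketch of the batched matrix multiplication that the `EnsembleLinear` docstring describes. The bias shape `(num_ensemble, 1, out_dim)` is taken from the diff above; the weight shape `(num_ensemble, in_dim, out_dim)` is an assumption inferred from the `torch.matmul` broadcasting, and the snippet is an illustration rather than the layer's actual forward code.

```python
import torch

# Illustrative only: a single batched matmul applies every ensemble member in parallel.
num_ensemble, batch, in_dim, out_dim = 4, 8, 16, 32
weight = torch.randn(num_ensemble, in_dim, out_dim)  # assumed weight layout
bias = torch.zeros(num_ensemble, 1, out_dim)         # bias shape as allocated above

h = torch.randn(num_ensemble, batch, in_dim)          # one input slice per ensemble member
out = torch.matmul(h, weight) + bias                  # -> (num_ensemble, batch, out_dim)
print(out.shape)  # torch.Size([4, 8, 32])
```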
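The μP readout hunks divide the input by `width_mult()` before the linear map, which gives the 1/width multiplier mentioned in the class docstring. Below is a back-of-the-envelope sketch of that scaling with a hypothetical helper, approximating `width_mult()` as `in_dim / base_width`; in the actual layer this value comes from mup's `set_base_shapes` via the weight's `infshape`.

```python
import torch


def mup_readout_scale(x: torch.Tensor, in_dim: int, base_width: int, output_mult: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: approximates the input scaling done by EnsembleMuReadoutGraphium.forward,
    # where width_mult() would normally be read from the muP infshape rather than computed here.
    width_mult = in_dim / base_width
    return output_mult * x / width_mult


x = torch.randn(4, 8, 64)  # (num_ensemble, batch, in_dim)
scaled = mup_readout_scale(x, in_dim=64, base_width=64)  # width_mult == 1 at the base width
print(torch.allclose(scaled, x))  # True: no rescaling when the model is at its base width
```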