
Commit 44f9eba

applied black linting
1 parent 76074ce commit 44f9eba
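
Note: a change like this is what black produces when run over the repository; the command-line equivalent is simply running the black CLI on the two files below. A minimal, hedged sketch of reproducing the same kind of rewrite programmatically (the line_length value here is an assumption, not read from the project's configuration):

# Hedged sketch: applying black to a single statement (assumes `black` is installed;
# line_length=100 is an assumed value, not taken from Graphium's config).
import black

src = (
    "assert num_ensemble_2 == num_ensemble, "
    'f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}"\n'
)
# black wraps statements longer than the configured line length, producing the
# parenthesized assert seen in the diff below.
print(black.format_str(src, mode=black.Mode(line_length=100)))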

File tree: 2 files changed (+43, -35 lines)

graphium/nn/architectures/global_architectures.py

Lines changed: 5 additions & 3 deletions
@@ -174,6 +174,7 @@ def __init__(
     def _parse_layers(self, layer_type, residual_type):
         # Parse the layer and residuals
         from graphium.utils.spaces import LAYERS_DICT, RESIDUALS_DICT
+
         self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, LAYERS_DICT)
         self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT)

@@ -532,7 +533,9 @@ def __init__(
         if num_ensemble_2 is None:
             layer_kwargs["num_ensemble"] = num_ensemble
         else:
-            assert num_ensemble_2 == num_ensemble, f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}"
+            assert (
+                num_ensemble_2 == num_ensemble
+            ), f"num_ensemble={num_ensemble} != num_ensemble_2={num_ensemble_2}"

         super().__init__(
             in_dim=in_dim,

@@ -559,7 +562,6 @@ def __init__(
         self.reduction = reduction
         self.reduction_fn = self._parse_reduction(reduction)

-
     def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optional[Callable]:
         r"""
         Parse the reduction argument.

@@ -587,10 +589,10 @@ def _parse_reduction(self, reduction: Optional[Union[str, Callable]]) -> Optiona
     def _parse_layers(self, layer_type, residual_type):
         # Parse the layer and residuals
         from graphium.utils.spaces import ENSEMBLE_LAYERS_DICT, RESIDUALS_DICT
+
         self.layer_class, self.layer_name = self._parse_class_from_dict(layer_type, ENSEMBLE_LAYERS_DICT)
         self.residual_class, self.residual_name = self._parse_class_from_dict(residual_type, RESIDUALS_DICT)

-
     def forward(self, h: torch.Tensor) -> torch.Tensor:
         r"""
         Apply the ensemble MLP on the input features, then reduce the output if specified.

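The forward docstring above ("Apply the ensemble MLP on the input features, then reduce the output if specified"), together with _parse_reduction accepting a string or callable, suggests the reduction collapses the stacked ensemble outputs. A hedged illustration; the tensor shape and the choice of dimension 0 as the ensemble dimension are assumptions, not taken from this diff:

import torch

# Hypothetical sketch of the "reduce the output" step, assuming the ensemble MLP
# emits a tensor of shape (num_ensemble, batch, out_dim).
h = torch.randn(4, 8, 32)         # assumed (num_ensemble, batch, out_dim)
reduction_fn = torch.mean         # e.g. a callable parsed from the string "mean"
reduced = reduction_fn(h, dim=0)  # collapses the ensemble dimension -> (8, 32)
print(reduced.shape)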
graphium/nn/ensemble_layers.py

Lines changed: 38 additions & 32 deletions
@@ -10,14 +10,16 @@

 from graphium.nn.base_layers import FCLayer, MLP

+
 class EnsembleLinear(nn.Module):
-    def __init__(self,
-                 in_dim: int,
-                 out_dim: int,
-                 num_ensemble: int,
-                 bias: bool = True,
-                 init_fn: Optional[Callable] = None,
-                 ):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        bias: bool = True,
+        init_fn: Optional[Callable] = None,
+    ):
         r"""
         Multiple linear layers that are applied in parallel with batched matrix multiplication with `torch.matmul`.

@@ -38,7 +40,7 @@ def __init__(self,
         if bias:
             self.bias = nn.Parameter(torch.Tensor(num_ensemble, 1, out_dim))
         else:
-            self.register_parameter('bias', None)
+            self.register_parameter("bias", None)

         # Initialize parameters
         self.init_fn = init_fn if init_fn is not None else mupi.xavier_uniform_

@@ -79,6 +81,7 @@ def forward(self, h: torch.Tensor) -> torch.Tensor:

         return h

+
 class EnsembleFCLayer(FCLayer):
     def __init__(
         self,

@@ -162,7 +165,9 @@ def __init__(

         # Linear layer, or MuReadout layer
         if not is_readout_layer:
-            self.linear = EnsembleLinear(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn)
+            self.linear = EnsembleLinear(
+                in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn
+            )
         else:
             self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, bias=bias)

@@ -180,20 +185,23 @@ def __repr__(self):
         rep = rep[:-1] + f", num_ensemble={self.linear.weight.shape[0]})"
         return rep

+
 class EnsembleMuReadoutGraphium(EnsembleLinear):
     """
     This layer implements an ensemble version of μP with a 1/width multiplier and a
     constant variance initialization for both weights and biases.
     """
-    def __init__(self,
-                 in_dim: int,
-                 out_dim: int,
-                 num_ensemble: int,
-                 bias: bool = True,
-                 init_fn: Optional[Callable] = None,
-                 readout_zero_init=False,
-                 output_mult=1.0
-                 ):
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        bias: bool = True,
+        init_fn: Optional[Callable] = None,
+        readout_zero_init=False,
+        output_mult=1.0,
+    ):
         self.output_mult = output_mult
         self.readout_zero_init = readout_zero_init
         self.base_width = in_dim

@@ -214,35 +222,35 @@ def reset_parameters(self) -> None:
         super().reset_parameters()

     def width_mult(self):
-        assert hasattr(self.weight, 'infshape'), (
-            'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
-            'switch to distributed training with '
-            'torch.nn.parallel.DistributedDataParallel instead'
+        assert hasattr(self.weight, "infshape"), (
+            "Please call set_base_shapes(...). If using torch.nn.DataParallel, "
+            "switch to distributed training with "
+            "torch.nn.parallel.DistributedDataParallel instead"
         )
         return self.weight.infshape.width_mult()

     def _rescale_parameters(self):
-        '''Rescale parameters to convert SP initialization to μP initialization.
+        """Rescale parameters to convert SP initialization to μP initialization.

         Warning: This method is NOT idempotent and should be called only once
         unless you know what you are doing.
-        '''
-        if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params:
+        """
+        if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params:
             raise RuntimeError(
                 "`_rescale_parameters` has been called once before already. "
                 "Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n"
                 "If you called `set_base_shapes` on a model loaded from a checkpoint, "
                 "or just want to re-set the base shapes of an existing model, "
                 "make sure to set the flag `rescale_params=False`.\n"
-                "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.")
+                "To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call."
+            )
         if self.bias is not None:
-            self.bias.data *= self.width_mult()**0.5
-        self.weight.data *= self.width_mult()**0.5
+            self.bias.data *= self.width_mult() ** 0.5
+        self.weight.data *= self.width_mult() ** 0.5
         self._has_rescaled_params = True

     def forward(self, x):
-        return super().forward(
-            self.output_mult * x / self.width_mult())
+        return super().forward(self.output_mult * x / self.width_mult())

     @property
     def absolute_width(self):

@@ -361,7 +369,6 @@ def __init__(
             last_layer_is_readout=last_layer_is_readout,
             droppath_rate=droppath_rate,
             constant_droppath_rate=constant_droppath_rate,
-
         )

         self.reduction = self._parse_reduction(reduction)

@@ -415,4 +422,3 @@ def __repr__(self):
         """
         rep = super().__repr__()
         rep = rep[:-1] + f", num_ensemble={self.layers[0].linear.weight.shape[0]})"
-

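For context on the EnsembleLinear docstring above ("Multiple linear layers that are applied in parallel with batched matrix multiplication with `torch.matmul`"), here is a minimal, hedged sketch of that mechanism. The bias shape (num_ensemble, 1, out_dim) comes from the diff; the weight shape and the broadcasting of a 2D input are assumptions:

import torch

# Hedged sketch of a batched ensemble linear layer, not Graphium's exact code.
num_ensemble, batch, in_dim, out_dim = 4, 8, 16, 32
weight = torch.randn(num_ensemble, in_dim, out_dim)  # assumed weight shape
bias = torch.zeros(num_ensemble, 1, out_dim)         # shape as in the diff above

h = torch.randn(batch, in_dim)
# (batch, in_dim) @ (num_ensemble, in_dim, out_dim) broadcasts to
# (num_ensemble, batch, out_dim); the bias broadcasts over the batch dimension.
out = torch.matmul(h, weight) + bias
print(out.shape)  # torch.Size([4, 8, 32])

Per the diff, EnsembleMuReadoutGraphium reuses this same forward pass, feeding it output_mult * x / self.width_mult().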