This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit a847f1a

Merge branch 'develop' into feature/transformer_sequence_sharding

2 parents: e727e3c + fd2bcf1
20 files changed: +319 -77 lines

Diff for: .pre-commit-config.yaml (+3 -8)

@@ -27,7 +27,7 @@ repos:
   - id: python-check-blanket-noqa # Check for # noqa: all
   - id: python-no-log-warn # Check for log.warn
 - repo: https://github.com/psf/black-pre-commit-mirror
-  rev: 24.8.0
+  rev: 24.10.0
   hooks:
   - id: black
     args: [--line-length=120]
@@ -40,7 +40,7 @@ repos:
     - --force-single-line-imports
    - --profile black
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.6.9
+  rev: v0.7.2
   hooks:
   - id: ruff
     args:
@@ -59,13 +59,8 @@ repos:
   hooks:
   - id: rstfmt
     exclude: 'cli/.*' # Because we use argparse
-- repo: https://github.com/b8raoult/pre-commit-docconvert
-  rev: "0.1.5"
-  hooks:
-  - id: docconvert
-    args: ["numpy"]
 - repo: https://github.com/tox-dev/pyproject-fmt
-  rev: "2.2.4"
+  rev: "v2.5.0"
   hooks:
   - id: pyproject-fmt
 - repo: https://github.com/jshwi/docsig # Check docstrings against function sig

Diff for: CHANGELOG.md (+3)

@@ -22,6 +22,8 @@ Keep it human-readable, your future self will thank you!
 - configurabilty of the dropout probability in the the MultiHeadSelfAttention module
 - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
 - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
+- Mask NaN values in training loss function [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271)
+- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
 - Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)

 ### Changed
@@ -33,6 +35,7 @@ Keep it human-readable, your future self will thank you!
 - ci: extened python versions to include 3.11 and 3.12 [#66](https://github.com/ecmwf/anemoi-models/pull/66)
 - Update copyright notice
 - Fix `__version__` import in init
+- Fix missing copyrights [#71](https://github.com/ecmwf/anemoi-models/pull/71)

 ### Removed

Diff for: docs/conf.py (+2 -2)

@@ -29,15 +29,15 @@

 project = "Anemoi Models"

-author = "ECMWF"
+author = "Anemoi contributors"

 year = datetime.datetime.now().year
 if year == 2024:
     years = "2024"
 else:
     years = "2024-%s" % (year,)

-copyright = "%s, ECMWF" % (years,)
+copyright = "%s, Anemoi contributors" % (years,)

 try:
     from anemoi.models._version import __version__

Diff for: pyproject.toml (+3 -3)

@@ -1,13 +1,12 @@
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

-# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/
-
 [build-system]
 build-backend = "setuptools.build_meta"

@@ -36,6 +35,7 @@ classifiers = [
   "Programming Language :: Python :: 3.10",
   "Programming Language :: Python :: 3.11",
   "Programming Language :: Python :: 3.12",
+  "Programming Language :: Python :: 3.13",
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]

Diff for: src/anemoi/models/__init__.py (+3 -1)

@@ -1,6 +1,8 @@
-# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
+# (C) Copyright 2024 Anemoi contributors.
+#
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

Diff for: src/anemoi/models/__main__.py (+2 -3)

@@ -1,12 +1,11 @@
-#!/usr/bin/env python
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
-#

 from anemoi.utils.cli import cli_main
 from anemoi.utils.cli import make_parser

Diff for: src/anemoi/models/commands/__init__.py (+2 -3)

@@ -1,12 +1,11 @@
-#!/usr/bin/env python
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
-#

 import os

Diff for: src/anemoi/models/data_indices/__init__.py (+8)

@@ -0,0 +1,8 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.

Diff for: src/anemoi/models/distributed/__init__.py (+8)

@@ -0,0 +1,8 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.

Diff for: src/anemoi/models/interface/__init__.py (+2 -2)

@@ -1,11 +1,11 @@
-# (C) Copyright 2024 ECMWF.
+# (C) Copyright 2024 Anemoi contributors.
 #
 # This software is licensed under the terms of the Apache Licence Version 2.0
 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
 # In applying this licence, ECMWF does not waive the privileges and immunities
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
-#

 import uuid

Diff for: src/anemoi/models/layers/__init__.py (+8)

@@ -0,0 +1,8 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.

Diff for: src/anemoi/models/layers/graph.py (+71 -1)

@@ -12,6 +12,7 @@
 import torch
 from torch import Tensor
 from torch import nn
+from torch_geometric.data import HeteroData


 class TrainableTensor(nn.Module):
@@ -36,8 +37,77 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None:
     def forward(self, x: Tensor, batch_size: int) -> Tensor:
         latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)]
         if self.trainable is not None:
-            latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size))
+            latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size))
         return torch.cat(
             latent,
             dim=-1,  # feature dimension
         )
+
+
+class NamedNodesAttributes(nn.Module):
+    """Named Nodes Attributes information.
+
+    Attributes
+    ----------
+    num_nodes : dict[str, int]
+        Number of nodes for each group of nodes.
+    attr_ndims : dict[str, int]
+        Total dimension of node attributes (non-trainable + trainable) for each group of nodes.
+    trainable_tensors : nn.ModuleDict
+        Dictionary of trainable tensors for each group of nodes.
+
+    Methods
+    -------
+    get_coordinates(self, name: str) -> Tensor
+        Get the coordinates of a set of nodes.
+    forward( self, name: str, batch_size: int) -> Tensor
+        Get the node attributes to be passed trough the graph neural network.
+    """
+
+    num_nodes: dict[str, int]
+    attr_ndims: dict[str, int]
+    trainable_tensors: dict[str, TrainableTensor]
+
+    def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
+        """Initialize NamedNodesAttributes."""
+        super().__init__()
+
+        self.define_fixed_attributes(graph_data, num_trainable_params)
+
+        self.trainable_tensors = nn.ModuleDict()
+        for nodes_name, nodes in graph_data.node_items():
+            self.register_coordinates(nodes_name, nodes.x)
+            self.register_tensor(nodes_name, num_trainable_params)
+
+    def define_fixed_attributes(self, graph_data: HeteroData, num_trainable_params: int) -> None:
+        """Define fixed attributes."""
+        nodes_names = list(graph_data.node_types)
+        self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in nodes_names}
+        self.attr_ndims = {
+            nodes_name: 2 * graph_data[nodes_name].x.shape[1] + num_trainable_params for nodes_name in nodes_names
+        }
+
+    def register_coordinates(self, name: str, node_coords: Tensor) -> None:
+        """Register coordinates."""
+        sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1)
+        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
+
+    def get_coordinates(self, name: str) -> Tensor:
+        """Return original coordinates."""
+        sin_cos_coords = getattr(self, f"latlons_{name}")
+        ndim = sin_cos_coords.shape[1] // 2
+        sin_values = sin_cos_coords[:, :ndim]
+        cos_values = sin_cos_coords[:, ndim:]
+        return torch.atan2(sin_values, cos_values)
+
+    def register_tensor(self, name: str, num_trainable_params: int) -> None:
+        """Register a trainable tensor."""
+        self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], num_trainable_params)
+
+    def forward(self, name: str, batch_size: int) -> Tensor:
+        """Returns the node attributes to be passed trough the graph neural network.
+
+        It includes both the coordinates and the trainable parameters.
+        """
+        latlons = getattr(self, f"latlons_{name}")
+        return self.trainable_tensors[name](latlons, batch_size)
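Two things change in this file: TrainableTensor.forward now moves the trainable parameter onto the input's device before concatenation, and a new NamedNodesAttributes module centralises per-node-set coordinates and trainable tensors. Below is a minimal usage sketch of the new class, not part of the commit: the import path is the one introduced in this diff, while the toy graph, the node-set names ("data", "hidden") and all sizes are illustrative assumptions.

    # Hypothetical example: build a tiny two-node-set graph and query its attributes.
    import torch
    from torch_geometric.data import HeteroData

    from anemoi.models.layers.graph import NamedNodesAttributes  # module added in this commit

    graph = HeteroData()
    graph["data"].x = torch.rand(40, 2) - 0.5    # toy lat/lon coordinates in radians
    graph["hidden"].x = torch.rand(10, 2) - 0.5

    node_attributes = NamedNodesAttributes(num_trainable_params=8, graph_data=graph)

    # attr_ndims[name] = 2 * coord_dim + num_trainable_params = 2 * 2 + 8 = 12
    print(node_attributes.attr_ndims)   # {'data': 12, 'hidden': 12}
    print(node_attributes.num_nodes)    # {'data': 40, 'hidden': 10}

    # forward(name, batch_size) repeats the sin/cos coordinates over the batch and
    # concatenates the per-node trainable tensor along the feature dimension.
    x_data = node_attributes("data", batch_size=4)
    print(x_data.shape)                 # torch.Size([160, 12])

    # get_coordinates recovers the original coordinates via atan2(sin, cos).
    print(node_attributes.get_coordinates("data").shape)   # torch.Size([40, 2])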

Diff for: src/anemoi/models/models/__init__.py (+8)

@@ -0,0 +1,8 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.

Diff for: src/anemoi/models/models/encoder_processor_decoder.py (+14 -51)

@@ -22,7 +22,7 @@
 from torch_geometric.data import HeteroData

 from anemoi.models.distributed.shapes import get_shape_shards
-from anemoi.models.layers.graph import TrainableTensor
+from anemoi.models.layers.graph import NamedNodesAttributes

 LOGGER = logging.getLogger(__name__)

@@ -56,42 +56,33 @@ def __init__(

         self._calculate_shapes_and_indices(data_indices)
         self._assert_matching_indices(data_indices)
-
-        self.multi_step = model_config.training.multistep_input
-
-        self._define_tensor_sizes(model_config)
-
-        # Create trainable tensors
-        self._create_trainable_attributes()
-
-        # Register lat/lon of nodes
-        self._register_latlon("data", self._graph_name_data)
-        self._register_latlon("hidden", self._graph_name_hidden)
-
         self.data_indices = data_indices

+        self.multi_step = model_config.training.multistep_input
         self.num_channels = model_config.model.num_channels

-        input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
+        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)
+
+        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]

         # Encoder data -> hidden
         self.encoder = instantiate(
             model_config.model.encoder,
             in_channels_src=input_dim,
-            in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
+            in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden],
             hidden_dim=self.num_channels,
             sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
-            src_grid_size=self._data_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
         )

         # Processor hidden -> hidden
         self.processor = instantiate(
             model_config.model.processor,
             num_channels=self.num_channels,
             sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
         )

         # Decoder hidden -> data
@@ -102,8 +93,8 @@ def __init__(
             hidden_dim=self.num_channels,
             out_channels_dst=self.num_output_channels,
             sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._data_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
         )

         # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
@@ -133,34 +124,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
             self._internal_output_idx,
         ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

-    def _define_tensor_sizes(self, config: DotDict) -> None:
-        self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
-        self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes
-
-        self.trainable_data_size = config.model.trainable_parameters.data
-        self.trainable_hidden_size = config.model.trainable_parameters.hidden
-
-    def _register_latlon(self, name: str, nodes: str) -> None:
-        """Register lat/lon buffers.
-
-        Parameters
-        ----------
-        name : str
-            Name to store the lat-lon coordinates of the nodes.
-        nodes : str
-            Name of nodes to map
-        """
-        coords = self._graph_data[nodes].x
-        sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
-
-    def _create_trainable_attributes(self) -> None:
-        """Create all trainable attributes."""
-        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
-        self.trainable_hidden = TrainableTensor(
-            trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
-        )
-
     def _run_mapper(
         self,
         mapper: nn.Module,
@@ -210,12 +173,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
         x_data_latent = torch.cat(
             (
                 einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
-                self.trainable_data(self.latlons_data, batch_size=batch_size),
+                self.node_attributes(self._graph_name_data, batch_size=batch_size),
             ),
             dim=-1,  # feature dimension
         )

-        x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)
+        x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)

         # get shard shapes
         shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
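This refactor replaces the model's own _define_tensor_sizes, _register_latlon and _create_trainable_attributes helpers with the single NamedNodesAttributes instance. Note that it is constructed from model_config.model.trainable_parameters.hidden alone, so one trainable size is now used for every node set, whereas the removed code read separate data and hidden sizes. A worked sketch of the resulting dimension bookkeeping follows; all numbers are hypothetical and do not come from the commit.

    # Hypothetical configuration values, for illustration only.
    multi_step = 2             # model_config.training.multistep_input
    num_input_channels = 90    # input channels per time step
    coord_dim = 2              # lat/lon; the sin/cos encoding doubles this
    num_trainable_params = 8   # model_config.model.trainable_parameters.hidden

    # NamedNodesAttributes.attr_ndims[name] = 2 * coord_dim + num_trainable_params
    attr_ndims = 2 * coord_dim + num_trainable_params           # 12, same for every node set

    # Encoder source channels (data nodes), as in the refactored __init__:
    input_dim = multi_step * num_input_channels + attr_ndims    # 2 * 90 + 12 = 192
    # Encoder destination channels (hidden nodes):
    in_channels_dst = attr_ndims                                 # 12

    # Grid sizes now come from node_attributes.num_nodes[...] rather than the
    # removed self._data_grid_size / self._hidden_grid_size attributes.
    print(input_dim, in_channels_dst)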
