fix(training,config)!: refactor node_weights #102

Draft · wants to merge 44 commits into base: 7-pressure-level-scalings-only-applied-in-specific-circumstances

Changes shown from 6 of 44 commits.

Commits
9cc30e8  update config file (JPXKQX, Jan 29, 2025)
9d18ec5  bring graph node attribute scaler (JPXKQX, Jan 29, 2025)
594c3f0  update (JPXKQX, Jan 30, 2025)
2e0c476  rename Loss (JPXKQX, Jan 30, 2025)
54289fc  refactor (JPXKQX, Jan 31, 2025)
e37927f  Merge branch '7-pressure-level-scalings-only-applied-in-specific-circ… (JPXKQX, Jan 31, 2025)
06c5ea7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
c9baca5  FunctionalLoss (JPXKQX, Jan 31, 2025)
e9d566d  Merge branch 'feature/move-extra-scalers' of https://github.com/ecmwf… (JPXKQX, Jan 31, 2025)
4b0810d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
9e7289e  rename (JPXKQX, Jan 31, 2025)
3a02e73  Merge branch 'feature/move-extra-scalers' of https://github.com/ecmwf… (JPXKQX, Jan 31, 2025)
06f7081  clean (JPXKQX, Jan 31, 2025)
7555ece  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
5a7a3a3  more (JPXKQX, Jan 31, 2025)
a0861a1  Merge branch 'feature/move-extra-scalers' of https://github.com/ecmwf… (JPXKQX, Jan 31, 2025)
e34d207  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
d6f2cab  add unit-norm for node_weights (JPXKQX, Jan 31, 2025)
4087814  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
1f3ba85  fix config (JPXKQX, Jan 31, 2025)
9ca3f02  Merge branch 'feature/move-extra-scalers' of https://github.com/ecmwf… (JPXKQX, Jan 31, 2025)
081691b  support output_mask applied to scalers over spatial dims (JPXKQX, Jan 31, 2025)
09a462d  refactor from BaseLoss to FunctionalLoss (JPXKQX, Jan 31, 2025)
b4c76b0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
0d59731  rename (JPXKQX, Jan 31, 2025)
f16ba04  rrefactor LAM MSE (JPXKQX, Jan 31, 2025)
4c5e8d6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
e93e81c  more rename (JPXKQX, Jan 31, 2025)
ae61573  working (WIP) (JPXKQX, Jan 31, 2025)
8319cf3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
98be990  nan_mask refactored (working) (JPXKQX, Jan 31, 2025)
9a2b7ed  Merge branch 'feature/move-extra-scalers' of https://github.com/ecmwf… (JPXKQX, Jan 31, 2025)
aaa126a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
7957885  update configs (JPXKQX, Jan 31, 2025)
3a56ce3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 31, 2025)
b6b44fa  update config (JPXKQX, Jan 31, 2025)
995a417  Merge branch 'feature/move-extra-scalers' of https://github.com/ecmwf… (JPXKQX, Jan 31, 2025)
b08989a  improve lam_wmse (JPXKQX, Feb 3, 2025)
6f2d887  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 3, 2025)
ad7d533  make scale_dims more explicit (metaclass) (JPXKQX, Feb 3, 2025)
af75bdd  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Feb 3, 2025)
3d6d4d3  clean (JPXKQX, Feb 3, 2025)
4dbe3b5  optionals (JPXKQX, Feb 3, 2025)
871e3df  fix pre-commit (JPXKQX, Feb 3, 2025)
4 changes: 2 additions & 2 deletions training/docs/modules/losses.rst
@@ -5,7 +5,7 @@
This module is used to define the loss function used to train the model.

Anemoi-training exposes a couple of loss functions by default to be
-used, all of which are subclassed from ``BaseWeightedLoss``. This class
+used, all of which are subclassed from ``BaseLoss``. This class
enables scaler multiplication, and graph node weighting.

.. automodule:: anemoi.training.losses.weightedloss
@@ -110,7 +110,7 @@ By default, only `all` is kept in the normalised space and scaled.
***********************

Additionally, you can define your own loss function by subclassing
-``BaseWeightedLoss`` and implementing the ``forward`` method, or by
+``BaseLoss`` and implementing the ``forward`` method, or by
subclassing ``FunctionalWeightedLoss`` and implementing the
``calculate_difference`` function. The latter abstracts the scaling, and
node weighting, and allows you to just specify the difference
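As a concrete companion to the docs change above, a minimal sketch of a custom loss built the second way, by subclassing ``FunctionalWeightedLoss`` and implementing ``calculate_difference``. The class body is illustrative only; the exact base-class interface is not shown in this diff.

import torch

from anemoi.training.losses.weightedloss import FunctionalWeightedLoss


class WeightedL1Loss(FunctionalWeightedLoss):
    """Hypothetical node-weighted L1 loss."""

    name = "wl1"

    def calculate_difference(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Only the pointwise difference is specified here; scaling and
        # node weighting are handled by the base class.
        return torch.abs(pred - target)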
4 changes: 2 additions & 2 deletions training/src/anemoi/training/config/training/default.yaml
@@ -1,6 +1,6 @@
---
defaults:
-  - scalers: scalers
+  - scalers: global

# resume or fork a training from a checkpoint last.ckpt or specified in hardware.files.warm_start
run_id: null
@@ -54,7 +54,7 @@ training_loss:
  # A selection of available scalers is listed in training/scalers/scalers.yaml
  # '*' is a valid entry to use all `scalers` given; if a scaler is to be excluded,
  # add `!scaler_name`, e.g. ['*', '!scaler_1'], and `scaler_1` will not be added.
-  scalers: ['pressure_level', 'general_variable', 'nan_mask_weights']
+  scalers: ['pressure_level', 'general_variable', 'nan_mask_weights', 'node_weights']
  ignore_nans: False

loss_gradient_scaling: False
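For clarity, a standalone sketch of how the `'*'` / `'!scaler_name'` entries above are resolved; it mirrors the selection logic in `get_loss_function` shown later in this diff, and the helper name is illustrative:

def resolve_scalers(requested: list[str], available: dict[str, object]) -> list[str]:
    # '*' selects every available scaler that is not excluded via '!name'.
    if "*" in requested:
        return [name for name in available if f"!{name}" not in requested]
    # Explicit lists are taken as-is (unknown names raise in get_loss_function).
    return list(requested)


# Example: all scalers except 'scaler_1'.
resolve_scalers(["*", "!scaler_1"], {"scaler_1": None, "scaler_2": None})  # -> ['scaler_2']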
@@ -4,6 +4,9 @@ variable_groups:

# Several scalers can be added here. In order to be applied, their names must be included in the loss.
# The scaler name must be included in `scalers` in the losses for this to be applied.
+# All scalers need a `scale_dim` argument representing the dimension(s) on which they are applied:
+#   -1 : channels dimension (timesteps, variables, ...)
+#   -2 : grid dimension
builders:
  general_variable:
    # Variable groups definition for scaling by variable level.
@@ -49,10 +52,10 @@ builders:
    _target_: anemoi.training.losses.scaling.variable_tendency.VarTendencyScaler
    scale_dim: -1 # dimension on which scaling applied

-  # Scalers from node attributes
-  node_weights:
-    _target_: anemoi.training.losses.nodeweights.GraphNodeAttribute
-    target_nodes: ${graph.data}
-    node_attribute: area_weight
-    scale_dim: 2 # dimension on which scaling applied
+  # limited_area_mask
+  node_weights:
+    _target_: anemoi.training.losses.scaling.node_attributes.GraphNodeAttributeScaler
+    nodes_name: ${graph.data}
+    nodes_attribute_name: area_weight
+    apply_output_mask: True
+    scale_dim: -2 # dimension on which scaling applied
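For intuition about the `scale_dim` convention used by these builders, a small sketch assuming the loss-tensor layout (batch, ensemble, grid, variable) implied by the -1/-2 comments above; the shapes are illustrative:

import torch

out = torch.randn(2, 1, 40320, 98)          # (batch, ensemble, grid, variable)
area_weight = torch.rand(40320)             # node-attribute scaler, scale_dim: -2
var_weight = torch.rand(98)                 # per-variable scaler, scale_dim: -1

out = out * area_weight.view(1, 1, -1, 1)   # broadcast along the grid dimension
out = out * var_weight.view(1, 1, 1, -1)    # broadcast along the channel dimension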
@@ -0,0 +1,66 @@
variable_groups:
  default: sfc
  pl: [q, t, u, v, w, z]

# Several scalers can be added here. In order to be applied, their names must be included in the loss.
# The scaler name must be included in `scalers` in the losses for this to be applied.
# All scalers need a `scale_dim` argument representing the dimension(s) on which they are applied:
#   -1 : channels dimension (timesteps, variables, ...)
#   -2 : grid dimension
builders:
  general_variable:
    # Variable groups definition for scaling by variable level.
    # The variable level scaling methods are defined under additional_scalers.
    # A default group is required and is appended as prefix to the metric of all variables not assigned to a group.
    _target_: anemoi.training.losses.scaling.variable.GeneralVariableLossScaler
    scale_dim: -1 # dimension on which scaling applied
    weights:
      default: 1
      q: 0.6 #1
      t: 6 #1
      u: 0.8 #0.5
      v: 0.5 #0.33
      w: 0.001
      z: 12 #1
      sp: 10
      10u: 0.1
      10v: 0.1
      2d: 0.5
      tp: 0.025
      cp: 0.0025

  pressure_level:
    _target_: anemoi.training.losses.scaling.variable_level.ReluVariableLevelScaler
    group: pl
    y_intercept: 0.2
    slope: 0.001
    scale_dim: -1 # dimension on which scaling applied

  # mask NaNs with zeros in the loss function
  nan_mask_weights:
    _target_: anemoi.training.losses.scaling.loss_weights_mask.NaNMaskScaler
    scale_dim: (-2, -1) # dimensions on which scaling applied

  # tendency scalers
  # scale the prognostic losses by the stdev of the variable tendencies (e.g. the 6-hourly differences of the data)
  # useful if including slow- vs fast-evolving variables in the training (e.g. Land/Ocean vs Atmosphere)
  # if using this option, 'variable_loss_scalings' should all be set close to 1.0 for prognostic variables
  stdev_tendency:
    _target_: anemoi.training.losses.scaling.variable_tendency.StdevTendencyScaler
    scale_dim: -1 # dimension on which scaling applied
  var_tendency:
    _target_: anemoi.training.losses.scaling.variable_tendency.VarTendencyScaler
    scale_dim: -1 # dimension on which scaling applied

  # Scalers from node attributes
  node_weights:
    _target_: anemoi.training.losses.scaling.node_attributes.GraphNodeAttributeScaler
    nodes_name: ${graph.data}
    nodes_attribute_name: area_weight
    scale_dim: -2 # dimension on which scaling applied

  limited_area_mask:
    _target_: anemoi.training.losses.scaling.node_attributes.GraphNodeAttributeScaler
    nodes_name: ${graph.data}
    nodes_attribute_name: cutout_mask
    scale_dim: -2
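The `ReluVariableLevelScaler` above is configured only through `y_intercept` and `slope`; its implementation is not part of this diff. One plausible reading of those parameters, shown purely as an illustration, is a ReLU-shaped weight over the pressure level:

def relu_level_weight(level: float, y_intercept: float = 0.2, slope: float = 0.001) -> float:
    # Hypothetical form: e.g. level=850 -> max(0.2, 0.85) = 0.85; level=100 -> max(0.2, 0.1) = 0.2
    return max(y_intercept, slope * level)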
6 changes: 3 additions & 3 deletions training/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -42,7 +42,7 @@
from anemoi.training.diagnostics.plots import plot_loss
from anemoi.training.diagnostics.plots import plot_power_spectrum
from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample
-from anemoi.training.losses.weightedloss import BaseWeightedLoss
+from anemoi.training.losses.weightedloss import BaseLoss

if TYPE_CHECKING:
    from typing import Any
@@ -855,9 +855,9 @@ def _plot(
        )
        self.parameter_names = [self.parameter_names[i] for i in argsort_indices]

-        if not isinstance(pl_module.loss, BaseWeightedLoss):
+        if not isinstance(pl_module.loss, BaseLoss):
            LOGGER.warning(
-                "Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.",
+                "Loss function must be a subclass of BaseLoss, or provide `squash`.",
                RuntimeWarning,
            )
4 changes: 2 additions & 2 deletions training/src/anemoi/training/losses/huber.py
@@ -13,12 +13,12 @@

import torch

-from anemoi.training.losses.weightedloss import BaseWeightedLoss
+from anemoi.training.losses.weightedloss import BaseLoss

LOGGER = logging.getLogger(__name__)


-class WeightedHuberLoss(BaseWeightedLoss):
+class WeightedHuberLoss(BaseLoss):
    """Node-weighted Huber loss."""

    name = "whuber"
4 changes: 2 additions & 2 deletions training/src/anemoi/training/losses/limitedarea.py
@@ -14,12 +14,12 @@

import torch

-from anemoi.training.losses.weightedloss import BaseWeightedLoss
+from anemoi.training.losses.weightedloss import BaseLoss

LOGGER = logging.getLogger(__name__)


-class WeightedMSELossLimitedArea(BaseWeightedLoss):
+class WeightedMSELossLimitedArea(BaseLoss):
    """Node-weighted MSE loss, calculated only within or outside the limited area.

    Further, the loss can be computed for the specified region (default),
13 changes: 6 additions & 7 deletions training/src/anemoi/training/losses/logcosh.py
@@ -42,7 +42,6 @@ class WeightedLogCoshLoss(BaseWeightedLoss):

    def __init__(
        self,
-        node_weights: torch.Tensor,
        ignore_nans: bool = False,
        **kwargs,
    ) -> None:
@@ -56,11 +55,7 @@ def __init__(
            Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False

        """
-        super().__init__(
-            node_weights=node_weights,
-            ignore_nans=ignore_nans,
-            **kwargs,
-        )
+        super().__init__(ignore_nans=ignore_nans, **kwargs)

    def forward(
        self,
@@ -94,4 +89,8 @@ def forward(
"""
out = LogCosh.apply(pred - target)
out = self.scale(out, scaler_indices, without_scalers=without_scalers)
return self.scale_by_node_weights(out, squash)

if squash:
out = self.avg_function(out, dim=-1)

return self.sum_function(out, dim=(0, 1, 2))
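The same reduction replaces `scale_by_node_weights` in the MAE and MSE losses below. As a standalone sketch, assuming `out` has shape (batch, ensemble, grid, variable) and that `avg_function`/`sum_function` resolve to `torch.mean`/`torch.sum` (or their nan-ignoring variants when `ignore_nans` is set):

import torch


def reduce_loss(out: torch.Tensor, squash: bool = True) -> torch.Tensor:
    if squash:
        out = torch.mean(out, dim=-1)     # average over the variable dimension
    # squash=True yields a scalar; squash=False yields one value per variable.
    return torch.sum(out, dim=(0, 1, 2))  # sum over batch, ensemble and grid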
83 changes: 83 additions & 0 deletions training/src/anemoi/training/losses/loss.py
@@ -0,0 +1,83 @@
# (C) Copyright 2025- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
from __future__ import annotations

import logging

import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf

from anemoi.training.losses.weightedloss import BaseLoss

LOGGER = logging.getLogger(__name__)


# Future import breaks other type hints TODO Harrison Cook
def get_loss_function(
    config: DictConfig,
    scalers: dict[str, tuple[int | tuple[int, ...], torch.Tensor]] | None = None,
    **kwargs,
) -> BaseLoss | torch.nn.ModuleList:
    """Get loss functions from config.

    Can be a ModuleList if multiple losses are specified.

    Parameters
    ----------
    config : DictConfig
        Loss function configuration; should include `scalers` if scalers are to be added to the loss function.
    scalers : dict[str, tuple[int | tuple[int, ...], torch.Tensor]], optional
        Scalers which can be added to the loss function, by default None.
        If a scaler is to be added to the loss, ensure it is listed in `scalers` in the loss config.
        E.g. if `scalers: ['variable']` is set in the config and `variable` is a key of `scalers`,
        `variable` will be added to the scalers of the loss function.
    kwargs : Any
        Additional arguments to pass to the loss function

    Returns
    -------
    Union[BaseLoss, torch.nn.ModuleList]
        Loss function, or list of metrics

    Raises
    ------
    TypeError
        If not a subclass of `BaseLoss`
    ValueError
        If a scaler is not found in the valid scalers
    """
    config_container = OmegaConf.to_container(config, resolve=False)
    if isinstance(config_container, list):
        return torch.nn.ModuleList(
            [get_loss_function(OmegaConf.create(loss_config), scalers=scalers, **kwargs) for loss_config in config],
        )

    loss_config = OmegaConf.to_container(config, resolve=True)
    scalers_to_include = loss_config.pop("scalers", [])

    if "*" in scalers_to_include:
        scalers_to_include = [s for s in list(scalers.keys()) if f"!{s}" not in scalers_to_include]

    # Instantiate the loss function with the loss_init_config
    loss_function = instantiate(loss_config, **kwargs)

    if not isinstance(loss_function, BaseLoss):
        error_msg = f"Loss must be a subclass of 'BaseLoss', not {type(loss_function)}"
        raise TypeError(error_msg)

    for key in scalers_to_include:
        if key not in (scalers or {}):
            error_msg = f"Scaler {key!r} not found in valid scalers: {list((scalers or {}).keys())}"
            raise ValueError(error_msg)
        loss_function.add_scaler(*scalers[key], name=key)

    return loss_function
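A hypothetical usage sketch for `get_loss_function`, assuming each scaler tuple pairs a `scale_dim` with a weight tensor, matching the `add_scaler(*scalers[key], name=key)` call above; the grid size is illustrative:

import torch
from omegaconf import OmegaConf

from anemoi.training.losses.loss import get_loss_function

config = OmegaConf.create(
    {
        "_target_": "anemoi.training.losses.mse.WeightedMSELoss",
        "scalers": ["node_weights"],
    },
)
scalers = {"node_weights": (-2, torch.rand(40320))}  # (scale_dim, weights)
loss = get_loss_function(config, scalers=scalers)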
19 changes: 8 additions & 11 deletions training/src/anemoi/training/losses/mae.py
@@ -14,19 +14,18 @@

import torch

-from anemoi.training.losses.weightedloss import BaseWeightedLoss
+from anemoi.training.losses.weightedloss import BaseLoss

LOGGER = logging.getLogger(__name__)


-class WeightedMAELoss(BaseWeightedLoss):
+class WeightedMAELoss(BaseLoss):
    """Node-weighted MAE loss."""

    name = "wmae"

    def __init__(
        self,
-        node_weights: torch.Tensor,
        ignore_nans: bool = False,
        **kwargs,
    ) -> None:
@@ -36,17 +35,11 @@ def __init__(

        Parameters
        ----------
-        node_weights : torch.Tensor of shape (N, )
-            Weight of each node in the loss function
        ignore_nans : bool, optional
            Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False

        """
-        super().__init__(
-            node_weights=node_weights,
-            ignore_nans=ignore_nans,
-            **kwargs,
-        )
+        super().__init__(ignore_nans=ignore_nans, **kwargs)

    def forward(
        self,
@@ -80,4 +73,8 @@ def forward(
"""
out = torch.abs(pred - target)
out = self.scale(out, scaler_indices, without_scalers=without_scalers)
return self.scale_by_node_weights(out, squash)

if squash:
out = self.avg_function(out, dim=-1)

return self.sum_function(out, dim=(0, 1, 2))
19 changes: 8 additions & 11 deletions training/src/anemoi/training/losses/mse.py
@@ -14,37 +14,30 @@

import torch

-from anemoi.training.losses.weightedloss import BaseWeightedLoss
+from anemoi.training.losses.weightedloss import BaseLoss

LOGGER = logging.getLogger(__name__)


-class WeightedMSELoss(BaseWeightedLoss):
+class WeightedMSELoss(BaseLoss):
    """Node-weighted MSE loss."""

    name = "wmse"

    def __init__(
        self,
-        node_weights: torch.Tensor,
        ignore_nans: bool = False,
        **kwargs,
    ) -> None:
"""Node- and feature weighted MSE Loss.

Parameters
----------
node_weights : torch.Tensor of shape (N, )
Weight of each node in the loss function
ignore_nans : bool, optional
Allow nans in the loss and apply methods ignoring nans for measuring the loss, by default False

"""
super().__init__(
node_weights=node_weights,
ignore_nans=ignore_nans,
**kwargs,
)
super().__init__(ignore_nans=ignore_nans, **kwargs)

    def forward(
        self,
@@ -77,4 +70,8 @@ def forward(
"""
out = torch.square(pred - target)
out = self.scale(out, scaler_indices, without_scalers=without_scalers)
return self.scale_by_node_weights(out, squash)

if squash:
out = self.avg_function(out, dim=-1)

return self.sum_function(out, dim=(0, 1, 2))