Skip to content

Commit

Permalink
fix normalise() and edge builder
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Feb 6, 2025
1 parent 2231c19 commit 02e11b9
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 25 deletions.
11 changes: 6 additions & 5 deletions graphs/src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ def message(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
if edge_features.ndim == 1:
edge_features = edge_features.unsqueeze(-1)

return self.normalise(edge_features)
return edge_features

def aggregate(self, edge_features: torch.Tensor) -> torch.Tensor:
return edge_features
return self.normalise(edge_features)


class BasePositionalBuilder(BaseEdgeAttributeBuilder, ABC):
Expand Down Expand Up @@ -150,9 +150,10 @@ def compute(self, x_i: torch.Tensor, _x_j: torch.Tensor) -> torch.Tensor:
return x_i

def forward(self, x: tuple[NodeStorage, NodeStorage], edge_index: Adj, size: Size = None) -> torch.Tensor:
return self.propagate(
edge_index, x=(x[self.node_idx][self.node_attr_name], x[self.node_idx][self.node_attr_name]), size=size
)
return self.propagate(edge_index, x=x, size=size)

def message(self, x_i: NodeStorage, x_j: NodeStorage) -> torch.Tensor:
return (x_i, x_j)[self.node_idx][self.node_attr_name]


class AttributeFromSourceNode(BaseAttributeFromNodeBuilder):
Expand Down
11 changes: 5 additions & 6 deletions graphs/src/anemoi/graphs/edges/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian

NORTH_POLE = torch.tensor([[0, 0, 1]], dtype=torch.float32) # North pole in 3D coordinates
NORTH_POLE = [0, 0, 1] # North pole in 3D coordinates


def direction_vec(points: torch.Tensor, reference: torch.Tensor, epsilon: float = 10e-11) -> torch.Tensor:
Expand Down Expand Up @@ -78,20 +78,19 @@ def compute_directions(source_coords: torch.Tensor, target_coords: torch.Tensor)
torch.Tensor of shape (N, 2)
The direction of the edge.
"""
north_pole = torch.tensor([NORTH_POLE], dtype=source_coords.dtype).to(device=source_coords.device)
source_coords_xyz = latlon_rad_to_cartesian(source_coords, 1.0)
target_coords_xyz = latlon_rad_to_cartesian(target_coords, 1.0)

# Compute the unit direction vector & the angle theta between target coords and the north pole.
v_unit = direction_vec(target_coords_xyz, NORTH_POLE.to(source_coords.device))
theta = torch.acos(
torch.clamp(torch.sum(target_coords_xyz * NORTH_POLE.to(source_coords.device), dim=1), -1.0, 1.0)
) # Clamp for numerical stability
v_unit = direction_vec(target_coords_xyz, north_pole)
theta = torch.acos(torch.clamp(torch.sum(target_coords_xyz * north_pole, dim=1), -1.0, 1.0)) # Clamp for numerical stability

# Rotate source coords by angle theta around v_unit axis.
rotated_source_coords_xyz = rotate_vectors(source_coords_xyz, v_unit, theta)

# Compute the direction from the rotated vector to the north pole.
direction = direction_vec(rotated_source_coords_xyz, NORTH_POLE.to(source_coords.device))
direction = direction_vec(rotated_source_coords_xyz, north_pole)
normed_direction = direction / torch.norm(direction, dim=1).unsqueeze(-1)

# All 3rd components should be 0s
Expand Down
4 changes: 2 additions & 2 deletions graphs/src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
if values.ndim == 1:
values = values[:, np.newaxis]

norm_values = self.normalise(values)
values = torch.tensor(values.astype(self.dtype))

return torch.tensor(norm_values.astype(self.dtype))
return self.normalise(values)

def compute(self, graph: HeteroData, nodes_name: str, **kwargs) -> torch.Tensor:
"""Get the nodes attribute.
Expand Down
14 changes: 7 additions & 7 deletions graphs/src/anemoi/graphs/normalise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@

import logging

import numpy as np
import torch

LOGGER = logging.getLogger(__name__)


class NormaliserMixin:
"""Mixin class for normalising attributes."""

def normalise(self, values: np.ndarray) -> np.ndarray:
def normalise(self, values: torch.Tensor) -> torch.Tensor:
"""Normalise the given values.
It supports different normalisation methods: None, 'l1',
Expand All @@ -37,15 +37,15 @@ def normalise(self, values: np.ndarray) -> np.ndarray:
LOGGER.debug(f"{self.__class__.__name__} values are not normalised.")
return values
if self.norm == "l1":
return values / np.sum(values)
return values / torch.sum(values)
if self.norm == "l2":
return values / np.linalg.norm(values)
return values / torch.norm(values)
if self.norm == "unit-max":
return values / np.amax(values)
return values / torch.amax(values)
if self.norm == "unit-range":
return (values - np.amin(values)) / (np.amax(values) - np.amin(values))
return (values - torch.amin(values)) / (torch.amax(values) - torch.amin(values))
if self.norm == "unit-std":
std = np.std(values)
std = torch.std(values)
if std == 0:
LOGGER.warning(f"Std. dev. of the {self.__class__.__name__} values is 0. Normalisation is skipped.")
return values
Expand Down
10 changes: 5 additions & 5 deletions graphs/tests/edges/test_edge_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_edge_attribute_from_node(attribute_builder_cls, graph_nodes_and_edges:
@pytest.mark.parametrize("attribute_builder", [EdgeDirection(), EdgeLength()])
def test_fail_edge_features(attribute_builder, graph_nodes_and_edges):
"""Test edge attribute builder fails with unknown nodes."""
with pytest.raises(AssertionError):
edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index
source_nodes = graph_nodes_and_edges[TEST_EDGES[0]]
target_nodes = graph_nodes_and_edges[TEST_EDGES[2]]
attribute_builder(x=(source_nodes, target_nodes), edge_index=edge_index)
#with pytest.raises(AssertionError):
edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index
source_nodes = graph_nodes_and_edges[TEST_EDGES[0]]
target_nodes = graph_nodes_and_edges[TEST_EDGES[2]]
attribute_builder(x=(source_nodes, target_nodes), edge_index=edge_index)

0 comments on commit 02e11b9

Please sign in to comment.