Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Feb 5, 2025
1 parent 6fbfe4c commit 2d37526
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 23 deletions.
33 changes: 15 additions & 18 deletions graphs/src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

import logging
from abc import ABC
from abc import ABC, ABCMeta

import torch
from torch_geometric.data.storage import NodeStorage
Expand All @@ -26,25 +26,18 @@
LOGGER = logging.getLogger(__name__)


class NodeAttributeMeta(type):
def __new__(cls, name: str, bases: tuple, class_dict: dict):
if ABC in bases:
return super().__new__(cls, name, bases, class_dict)

if "node_attr_name" not in class_dict:
error_msg = f"Class {name} must define 'node_attr_name'"
raise TypeError(error_msg)

return super().__new__(cls, name, bases, class_dict)


class BaseEdgeAttributeBuilder(MessagePassing, NormaliserMixin, metaclass=NodeAttributeMeta):
class BaseEdgeAttributeBuilder(MessagePassing, NormaliserMixin):
"""Base class for edge attribute builders."""

node_attr_name: str = None

def __init__(self, norm: str | None = None, dtype: str = "float32") -> None:
super().__init__()
self.norm = norm
self.dtype = dtype
if self.node_attr_name is None:
error_msg = f"Class {self.__class__.__name__} must define 'node_attr_name' either as a class attribute or in __init__"
raise TypeError(error_msg)

def subset_node_information(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> PairTensor:
return source_nodes[self.node_attr_name], target_nodes[self.node_attr_name]
Expand Down Expand Up @@ -83,7 +76,6 @@ def aggregate(self, edge_features: torch.Tensor) -> torch.Tensor:


class BasePositionalBuilder(BaseEdgeAttributeBuilder, ABC):

node_attr_name: str = "x"
_idx_lat: int = 0
_idx_lon: int = 1
Expand Down Expand Up @@ -151,11 +143,16 @@ class BaseAttributeFromNodeBuilder(BooleanBaseEdgeAttributeBuilder, ABC):
"""Base class for propagating an attribute from the nodes to the edges."""

def __init__(self, node_attr_name: str) -> None:
super().__init__()
self.node_attr_name = node_attr_name
super().__init__()

def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
return (x_i, x_j)[self.node_idx]
def compute(self, x_i: torch.Tensor, _x_j: torch.Tensor) -> torch.Tensor:
return x_i

def forward(self, x: tuple[NodeStorage, NodeStorage], edge_index: Adj, size: Size = None) -> torch.Tensor:
return self.propagate(
edge_index, x=(x[self.node_idx][self.node_attr_name], x[self.node_idx][self.node_attr_name]), size=size
)


class AttributeFromSourceNode(BaseAttributeFromNodeBuilder):
Expand Down
8 changes: 3 additions & 5 deletions graphs/tests/edges/test_edge_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@

import pytest
import torch
from typing import TYPE_CHECKING

from anemoi.graphs.edges.attributes import AttributeFromSourceNode
from anemoi.graphs.edges.attributes import AttributeFromTargetNode
from anemoi.graphs.edges.attributes import EdgeDirection
from anemoi.graphs.edges.attributes import EdgeLength

if TYPE_CHECKING:
from torch_geometric.data import HeteroData
from torch_geometric.data import HeteroData

TEST_EDGES = ("test_nodes", "to", "test_nodes")

Expand Down Expand Up @@ -52,7 +50,7 @@ def test_edge_attribute_from_node(attribute_builder_cls, graph_nodes_and_edges:
edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index
source_nodes = graph_nodes_and_edges[TEST_EDGES[0]]
target_nodes = graph_nodes_and_edges[TEST_EDGES[2]]
edge_attr = edge_attr_builder.compute(x=(source_nodes, target_nodes), edge_index=edge_index)
edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index)
assert isinstance(edge_attr, torch.Tensor)


Expand All @@ -63,4 +61,4 @@ def test_fail_edge_features(attribute_builder, graph_nodes_and_edges):
edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index
source_nodes = graph_nodes_and_edges[TEST_EDGES[0]]
target_nodes = graph_nodes_and_edges[TEST_EDGES[2]]
attribute_builder.compute(x=(source_nodes, target_nodes), edge_index=edge_index)
attribute_builder(x=(source_nodes, target_nodes), edge_index=edge_index)

0 comments on commit 2d37526

Please sign in to comment.