From 2d375267828f118ce5abf1380a235f30daef5af5 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 5 Feb 2025 17:08:05 +0000 Subject: [PATCH] fix --- graphs/src/anemoi/graphs/edges/attributes.py | 33 +++++++++----------- graphs/tests/edges/test_edge_attributes.py | 8 ++--- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/graphs/src/anemoi/graphs/edges/attributes.py b/graphs/src/anemoi/graphs/edges/attributes.py index acd351b6..9b69c801 100644 --- a/graphs/src/anemoi/graphs/edges/attributes.py +++ b/graphs/src/anemoi/graphs/edges/attributes.py @@ -10,7 +10,7 @@ from __future__ import annotations import logging -from abc import ABC +from abc import ABC, ABCMeta import torch from torch_geometric.data.storage import NodeStorage @@ -26,25 +26,18 @@ LOGGER = logging.getLogger(__name__) -class NodeAttributeMeta(type): - def __new__(cls, name: str, bases: tuple, class_dict: dict): - if ABC in bases: - return super().__new__(cls, name, bases, class_dict) - - if "node_attr_name" not in class_dict: - error_msg = f"Class {name} must define 'node_attr_name'" - raise TypeError(error_msg) - - return super().__new__(cls, name, bases, class_dict) - - -class BaseEdgeAttributeBuilder(MessagePassing, NormaliserMixin, metaclass=NodeAttributeMeta): +class BaseEdgeAttributeBuilder(MessagePassing, NormaliserMixin): """Base class for edge attribute builders.""" + node_attr_name: str = None + def __init__(self, norm: str | None = None, dtype: str = "float32") -> None: super().__init__() self.norm = norm self.dtype = dtype + if self.node_attr_name is None: + error_msg = f"Class {self.__class__.__name__} must define 'node_attr_name' either as a class attribute or in __init__" + raise TypeError(error_msg) def subset_node_information(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> PairTensor: return source_nodes[self.node_attr_name], target_nodes[self.node_attr_name] @@ -83,7 +76,6 @@ def aggregate(self, edge_features: torch.Tensor) -> torch.Tensor: class BasePositionalBuilder(BaseEdgeAttributeBuilder, ABC): - node_attr_name: str = "x" _idx_lat: int = 0 _idx_lon: int = 1 @@ -151,11 +143,16 @@ class BaseAttributeFromNodeBuilder(BooleanBaseEdgeAttributeBuilder, ABC): """Base class for propagating an attribute from the nodes to the edges.""" def __init__(self, node_attr_name: str) -> None: - super().__init__() self.node_attr_name = node_attr_name + super().__init__() - def compute(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor: - return (x_i, x_j)[self.node_idx] + def compute(self, x_i: torch.Tensor, _x_j: torch.Tensor) -> torch.Tensor: + return x_i + + def forward(self, x: tuple[NodeStorage, NodeStorage], edge_index: Adj, size: Size = None) -> torch.Tensor: + return self.propagate( + edge_index, x=(x[self.node_idx][self.node_attr_name], x[self.node_idx][self.node_attr_name]), size=size + ) class AttributeFromSourceNode(BaseAttributeFromNodeBuilder): diff --git a/graphs/tests/edges/test_edge_attributes.py b/graphs/tests/edges/test_edge_attributes.py index 22a8c1b6..b16819b1 100644 --- a/graphs/tests/edges/test_edge_attributes.py +++ b/graphs/tests/edges/test_edge_attributes.py @@ -10,15 +10,13 @@ import pytest import torch -from typing import TYPE_CHECKING from anemoi.graphs.edges.attributes import AttributeFromSourceNode from anemoi.graphs.edges.attributes import AttributeFromTargetNode from anemoi.graphs.edges.attributes import EdgeDirection from anemoi.graphs.edges.attributes import EdgeLength -if TYPE_CHECKING: - from torch_geometric.data import HeteroData +from torch_geometric.data import HeteroData TEST_EDGES = ("test_nodes", "to", "test_nodes") @@ -52,7 +50,7 @@ def test_edge_attribute_from_node(attribute_builder_cls, graph_nodes_and_edges: edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] - edge_attr = edge_attr_builder.compute(x=(source_nodes, target_nodes), edge_index=edge_index) + edge_attr = edge_attr_builder(x=(source_nodes, target_nodes), edge_index=edge_index) assert isinstance(edge_attr, torch.Tensor) @@ -63,4 +61,4 @@ def test_fail_edge_features(attribute_builder, graph_nodes_and_edges): edge_index = graph_nodes_and_edges[TEST_EDGES].edge_index source_nodes = graph_nodes_and_edges[TEST_EDGES[0]] target_nodes = graph_nodes_and_edges[TEST_EDGES[2]] - attribute_builder.compute(x=(source_nodes, target_nodes), edge_index=edge_index) + attribute_builder(x=(source_nodes, target_nodes), edge_index=edge_index)