@@ -50,6 +50,33 @@ def assert_two_dimensional(tensor: Tensor):
50
50
51
51
52
52
class EdgeIndex (Tensor ):
53
+ r"""An advanced :obj:`edge_index` representation with additional (meta)data
54
+ attached.
55
+
56
+ :class:`EdgeIndex` is a :pytorch:`PyTorch` tensor, that holds an
57
+ :obj:`edge_index` representation of shape :obj:`[2, num_edges]`.
58
+ Edges are given as pairwise source and destination node indices in sparse
59
+ COO format.
60
+
61
+ While :class:`EdgeIndex` sub-classes a general :pytorch:`PyTorch` tensor,
62
+ it can hold additional (meta)data, *i.e.*:
63
+
64
+ * :obj:`sparse_size`: The underlying sparse matrix size
65
+ * :obj:`sort_order`: The sort order (if present), either by row or column.
66
+
67
+ Additionally, :class:`EdgeIndex` caches data for fast CSR or CSC conversion
68
+ in case its representation is sorted, such as its :obj:`rowptr` or
69
+ :obj:`colptr`, or the permutation vectors for fast conversion from CSR to
70
+ CSC and vice versa.
71
+ Caches are filled based on demand (*e.g.*, when calling
72
+ :meth:`EdgeIndex.sort_by`), or when explicitly requested via
73
+ :meth:`EdgeIndex.fill_cache`, and are maintained and adjusted over its
74
+ lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).
75
+
76
+ This representation ensures for optimal computation in GNN message passing
77
+ schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`
78
+ workflows.
79
+ """
53
80
# See "https://pytorch.org/docs/stable/notes/extending.html"
54
81
# for a basic tutorial on how to subclass `torch.Tensor`.
55
82
@@ -133,6 +160,9 @@ def validate(self) -> 'EdgeIndex':
133
160
return self
134
161
135
162
def fill_cache (self ) -> 'EdgeIndex' :
163
+ r"""Fills the cache with (meta)data information.
164
+ No-op in case :class:`EdgeIndex` is not sorted.
165
+ """
136
166
if self ._sort_order == SortOrder .ROW and self ._rowptr is None :
137
167
if self .num_rows is None :
138
168
self ._sparse_size = (
0 commit comments