diff --git a/test/data/test_edge_index.py b/test/data/test_edge_index.py index 6e122e50c800..27ff274e7cea 100644 --- a/test/data/test_edge_index.py +++ b/test/data/test_edge_index.py @@ -82,6 +82,17 @@ def test_share_memory(): assert adj._rowptr.is_shared() +def test_contiguous(): + data = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]]).t() + + with pytest.raises(ValueError, match="needs to be contiguous"): + EdgeIndex(data) + + adj = EdgeIndex(data.contiguous()).contiguous() + assert isinstance(adj, EdgeIndex) + assert adj.is_contiguous() + + def test_sort_by(): adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row') out = adj.sort_by('row') diff --git a/torch_geometric/data/edge_index.py b/torch_geometric/data/edge_index.py index 4011a6d99e3c..2a2bf38c70e8 100644 --- a/torch_geometric/data/edge_index.py +++ b/torch_geometric/data/edge_index.py @@ -49,6 +49,12 @@ def assert_two_dimensional(tensor: Tensor): f"[2, *] (got {list(tensor.size())})") +def assert_contiguous(tensor: Tensor): + if not tensor.is_contiguous(): + raise ValueError("'EdgeIndex' needs to be contiguous. Please call " + "`edge_index.contiguous()` before proceeding.") + + class EdgeIndex(Tensor): r"""An advanced :obj:`edge_index` representation with additional (meta)data attached. @@ -114,6 +120,7 @@ def __new__( assert isinstance(data, Tensor) assert_valid_dtype(data) assert_two_dimensional(data) + assert_contiguous(data) out = super().__new__(cls, data) @@ -130,6 +137,7 @@ def validate(self) -> 'EdgeIndex': """ assert_valid_dtype(self) assert_two_dimensional(self) + assert_contiguous(self) if self.numel() > 0 and self.min() < 0: raise ValueError(f"'{self.__class__.__name__}' contains negative " @@ -346,6 +354,11 @@ def share_memory_(tensor: EdgeIndex) -> EdgeIndex: return apply_(tensor, Tensor.share_memory_) +@implements(Tensor.contiguous) +def contiguous(tensor: EdgeIndex) -> EdgeIndex: + return apply_(tensor, Tensor.contiguous) + + @implements(torch.cat) def cat( tensors: List[Union[EdgeIndex, Tensor]],