diff --git a/test/data/test_edge_index.py b/test/data/test_edge_index.py index 14d5a30e1d76..2388213bdfb2 100644 --- a/test/data/test_edge_index.py +++ b/test/data/test_edge_index.py @@ -14,6 +14,7 @@ _scatter_spmm, _torch_sparse_spmm, _TorchSPMM, + set_tuple_item, ) from torch_geometric.profile import benchmark from torch_geometric.testing import ( @@ -48,6 +49,8 @@ def test_basic(dtype, device): assert adj.dtype == dtype assert adj.device == device assert adj.sparse_size() == (3, 3) + assert adj.sparse_size(0) == 3 + assert adj.sparse_size(-1) == 3 assert adj.sort_order is None assert not adj.is_sorted @@ -67,6 +70,45 @@ def test_basic(dtype, device): assert out.device == device +def test_set_tuple_item(): + tmp = (0, 1, 2) + assert set_tuple_item(tmp, 0, 3) == (3, 1, 2) + assert set_tuple_item(tmp, 1, 3) == (0, 3, 2) + assert set_tuple_item(tmp, 2, 3) == (0, 1, 3) + with pytest.raises(IndexError, match="tuple index out of range"): + set_tuple_item(tmp, 3, 3) + assert set_tuple_item(tmp, -1, 3) == (0, 1, 3) + assert set_tuple_item(tmp, -2, 3) == (0, 3, 2) + assert set_tuple_item(tmp, -3, 3) == (3, 1, 2) + with pytest.raises(IndexError, match="tuple index out of range"): + set_tuple_item(tmp, -4, 3) + + +def test_validate(): + with pytest.raises(ValueError, match="unsupported data type"): + EdgeIndex([[0.0, 1.0], [1.0, 0.0]]) + with pytest.raises(ValueError, match="needs to be two-dimensional"): + EdgeIndex([[[0], [1]], [[1], [0]]]) + with pytest.raises(ValueError, match="needs to have a shape of"): + EdgeIndex([[0, 1], [1, 0], [1, 1]]) + with pytest.raises(ValueError, match="received a non-symmetric size"): + EdgeIndex([[0, 1], [1, 0]], is_undirected=True, sparse_size=(2, 3)) + with pytest.raises(TypeError, match="invalid combination of arguments"): + EdgeIndex(torch.tensor([[0, 1], [1, 0]]), torch.long) + with pytest.raises(TypeError, match="invalid keyword arguments"): + EdgeIndex(torch.tensor([[0, 1], [1, 0]]), dtype=torch.long) + with pytest.raises(ValueError, match="contains negative indices"): + EdgeIndex([[-1, 0], [0, 1]]).validate() + with pytest.raises(ValueError, match="than its number of rows"): + EdgeIndex([[0, 10], [1, 0]], sparse_size=(2, 2)).validate() + with pytest.raises(ValueError, match="than its number of columns"): + EdgeIndex([[0, 1], [10, 0]], sparse_size=(2, 2)).validate() + with pytest.raises(ValueError, match="not sorted by row indices"): + EdgeIndex([[1, 0], [0, 1]], sort_order='row').validate() + with pytest.raises(ValueError, match="not sorted by column indices"): + EdgeIndex([[0, 1], [1, 0]], sort_order='col').validate() + + @withCUDA @pytest.mark.parametrize('dtype', DTYPES) def test_undirected(dtype, device): @@ -78,7 +120,14 @@ def test_undirected(dtype, device): assert adj.sparse_size() == (None, None) adj.get_num_rows() assert adj.sparse_size() == (3, 3) + adj.validate() + + adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(3, None), **kwargs) + assert adj.sparse_size() == (3, 3) + adj.validate() + adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(None, 3), **kwargs) + assert adj.sparse_size() == (3, 3) adj.validate() with pytest.raises(ValueError, match="'EdgeIndex' is not undirected"): @@ -164,6 +213,10 @@ def test_to(dtype, device, is_undirected): assert adj._indptr.device == device assert adj._T_perm.device == device + out = adj.cpu() + assert isinstance(out, EdgeIndex) + assert out.device == torch.device('cpu') + out = adj.to(torch.int) assert out.dtype == torch.int if torch_geometric.typing.WITH_PT20: @@ -285,6 +338,10 @@ def test_cat(dtype, device, is_undirected): args = dict(dtype=dtype, device=device, is_undirected=is_undirected) adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **args) adj2 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], sparse_size=(4, 4), **args) + adj3 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], dtype=dtype, device=device) + + out = torch.cat([adj1], dim=1) + assert id(out) == id(adj1) out = torch.cat([adj1, adj2], dim=1) assert out.size() == (2, 8) @@ -293,6 +350,13 @@ def test_cat(dtype, device, is_undirected): assert not out.is_sorted assert out.is_undirected == is_undirected + out = torch.cat([adj1, adj2, adj3], dim=1) + assert out.size() == (2, 12) + assert isinstance(out, EdgeIndex) + assert out.sparse_size() == (None, None) + assert not out.is_sorted + assert not out.is_undirected + out = torch.cat([adj1, adj2], dim=0) assert out.size() == (4, 4) assert not isinstance(out, EdgeIndex) @@ -388,6 +452,9 @@ def test_getitem(dtype, device, is_undirected): out = adj[tensor([0], device=device)] assert not isinstance(out, EdgeIndex) + out = adj[tensor([0], device=device), tensor([0], device=device)] + assert not isinstance(out, EdgeIndex) + @withCUDA @pytest.mark.parametrize('dtype', DTYPES) @@ -426,6 +493,11 @@ def test_to_dense(dtype, device, value_dtype): def test_to_sparse_coo(dtype, device): kwargs = dict(dtype=dtype, device=device) adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs) + + if torch_geometric.typing.WITH_PT20: + with pytest.raises(ValueError, match="Unexpected tensor layout"): + adj.to_sparse(layout='int64') + if torch_geometric.typing.WITH_PT20: out = adj.to_sparse(layout=torch.sparse_coo) else: @@ -436,6 +508,7 @@ def test_to_sparse_coo(dtype, device): assert out.layout == torch.sparse_coo assert out.size() == (3, 3) assert adj.equal(out._indices()) + assert not out.is_coalesced() adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs) out = adj.to_sparse_coo() @@ -445,6 +518,17 @@ def test_to_sparse_coo(dtype, device): assert out.layout == torch.sparse_coo assert out.size() == (3, 3) assert adj.equal(out._indices()) + assert not out.is_coalesced() + + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + out = adj.to_sparse_coo() + assert isinstance(out, Tensor) + assert out.dtype == torch.float + assert out.device == device + assert out.layout == torch.sparse_coo + assert out.size() == (3, 3) + assert adj.equal(out._indices()) + assert out.is_coalesced() @withCUDA @@ -632,6 +716,9 @@ def test_spmm(device, reduce, transpose, is_undirected): # Basic: x = torch.randn(3, 1, device=device) + with pytest.raises(ValueError, match="to be sorted by"): + adj.matmul(x, reduce=reduce, transpose=not transpose) + out = adj.matmul(x, reduce=reduce, transpose=transpose) exp = _scatter_spmm(adj, x, None, reduce, transpose) assert out.allclose(exp) @@ -640,6 +727,9 @@ def test_spmm(device, reduce, transpose, is_undirected): x = torch.randn(3, 1, device=device) value = torch.rand(adj.size(1), device=device) + with pytest.raises(ValueError, match="'other_value' not supported"): + adj.matmul(x, reduce=reduce, other_value=value, transpose=transpose) + out = adj.matmul(x, value, reduce=reduce, transpose=transpose) exp = _scatter_spmm(adj, x, value, reduce, transpose) assert out.allclose(exp) diff --git a/torch_geometric/data/edge_index.py b/torch_geometric/data/edge_index.py index 2ca583ca5cb2..d8024692d259 100644 --- a/torch_geometric/data/edge_index.py +++ b/torch_geometric/data/edge_index.py @@ -26,7 +26,7 @@ torch.int32, torch.int64, } -else: +else: # pragma: no cover SUPPORTED_DTYPES: Set[torch.dtype] = { torch.int64, } @@ -83,8 +83,8 @@ def assert_contiguous(tensor: Tensor): def assert_symmetric(size: Tuple[Optional[int], Optional[int]]): if size[0] is not None and size[1] is not None and size[0] != size[1]: - raise ValueError("'EdgeIndex' is undirected but received a " - "non-symmetric size") + raise ValueError(f"'EdgeIndex' is undirected but received a " + f"non-symmetric size (got {list(size)})") def assert_sorted(func): @@ -101,7 +101,7 @@ def wrapper(*args, **kwargs): class EdgeIndex(Tensor): - r"""An COO :obj:`edge_index` tensor with additional (meta)data attached. + r"""A COO :obj:`edge_index` tensor with additional (meta)data attached. :class:`EdgeIndex` is a :pytorch:`null` class:`torch.Tensor`, that holds an :obj:`edge_index` representation of shape :obj:`[2, num_edges]`. @@ -143,14 +143,14 @@ class EdgeIndex(Tensor): >>> EdgeIndex([[0, 1, 1, 2], ... [1, 0, 2, 1]]) assert edge_index.is_sorted_by_row - assert not edge_index.is_undirected + assert edge_index.is_undirected # Flipping order: edge_index = edge_index.flip(0) >>> EdgeIndex([[1, 0, 2, 1], ... [0, 1, 1, 2]]) assert edge_index.is_sorted_by_col - assert not edge_index.is_undirected + assert edge_index.is_undirected # Filtering: mask = torch.tensor([True, True, True, False]) @@ -161,7 +161,7 @@ class EdgeIndex(Tensor): assert not edge_index.is_undirected # Sparse-Dense Matrix Multiplication: - out = edge_index @ torch.randn(3, 16) + out = edge_index.flip(0) @ torch.randn(3, 16) assert out.size() == (3, 16) """ # See "https://pytorch.org/docs/stable/notes/extending.html" @@ -293,6 +293,10 @@ def sparse_size( r"""The size of the underlying sparse matrix. If :obj:`dim` is specified, returns an integer holding the size of that sparse dimension. + + Args: + dim (int, optional): The dimension for which to retrieve the size. + (default: :obj:`None`) """ if dim is not None: return self._sparse_size[dim] @@ -345,6 +349,10 @@ def get_sparse_size( Automatically computed and cached when not explicitly set. If :obj:`dim` is specified, returns an integer holding the size of that sparse dimension. + + Args: + dim (int, optional): The dimension for which to retrieve the size. + (default: :obj:`None`) """ if dim is not None: if self._sparse_size[dim] is not None: @@ -466,7 +474,7 @@ def _get_value(self, dtype: Optional[torch.dtype] = None) -> Tensor: if torch_geometric.typing.WITH_PT20 and not self.is_cuda: value = torch.ones(1, dtype=dtype, device=self.device) value = value.expand(self.size(1)) - else: + else: # pragma: no cover value = torch.ones(self.size(1), dtype=dtype, device=self.device) self._value = value @@ -844,7 +852,7 @@ def cpu(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: @implements(Tensor.cuda) -def cuda(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: +def cuda(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: # pragma: no cover return apply_(tensor, Tensor.cuda, *args, **kwargs) @@ -1098,9 +1106,9 @@ def forward( if torch_geometric.typing.WITH_PT20 and not other.is_cuda: return torch.sparse.mm(adj, other, reduce) - - assert reduce == 'sum' - return adj @ other + else: # pragma: no cover + assert reduce == 'sum' + return adj @ other @staticmethod def backward( @@ -1195,7 +1203,8 @@ def _spmm( raise ValueError(f"'matmul(..., transpose=True)' requires " f"'{cls_name}' to be sorted by colums") - if torch_geometric.typing.WITH_TORCH_SPARSE and other.is_cuda: + if (torch_geometric.typing.WITH_TORCH_SPARSE # pragma: no cover + and other.is_cuda): return _torch_sparse_spmm(input, other, value, reduce, transpose) if value is not None and value.requires_grad: