Increase code coverage in EdgeIndex (pyg-team#8531)
rusty1s authored Dec 5, 2023
1 parent 1e92ba2 commit 1d3bdc6
Showing 2 changed files with 112 additions and 13 deletions.
90 changes: 90 additions & 0 deletions test/data/test_edge_index.py
@@ -14,6 +14,7 @@
_scatter_spmm,
_torch_sparse_spmm,
_TorchSPMM,
set_tuple_item,
)
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
@@ -48,6 +49,8 @@ def test_basic(dtype, device):
assert adj.dtype == dtype
assert adj.device == device
assert adj.sparse_size() == (3, 3)
assert adj.sparse_size(0) == 3
assert adj.sparse_size(-1) == 3

assert adj.sort_order is None
assert not adj.is_sorted
@@ -67,6 +70,45 @@ def test_basic(dtype, device):
assert out.device == device


def test_set_tuple_item():
tmp = (0, 1, 2)
assert set_tuple_item(tmp, 0, 3) == (3, 1, 2)
assert set_tuple_item(tmp, 1, 3) == (0, 3, 2)
assert set_tuple_item(tmp, 2, 3) == (0, 1, 3)
with pytest.raises(IndexError, match="tuple index out of range"):
set_tuple_item(tmp, 3, 3)
assert set_tuple_item(tmp, -1, 3) == (0, 1, 3)
assert set_tuple_item(tmp, -2, 3) == (0, 3, 2)
assert set_tuple_item(tmp, -3, 3) == (3, 1, 2)
with pytest.raises(IndexError, match="tuple index out of range"):
set_tuple_item(tmp, -4, 3)
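For reference, since the helper itself sits outside this diff: a minimal sketch of what set_tuple_item (imported above alongside the other EdgeIndex helpers) is expected to do, inferred from these assertions; the actual implementation may differ:

def set_tuple_item(values: tuple, dim: int, value) -> tuple:
    # Mirror regular tuple indexing semantics, including negative indices:
    if dim < -len(values) or dim >= len(values):
        raise IndexError('tuple index out of range')
    dim = dim + len(values) if dim < 0 else dim
    return values[:dim] + (value, ) + values[dim + 1:]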


def test_validate():
with pytest.raises(ValueError, match="unsupported data type"):
EdgeIndex([[0.0, 1.0], [1.0, 0.0]])
with pytest.raises(ValueError, match="needs to be two-dimensional"):
EdgeIndex([[[0], [1]], [[1], [0]]])
with pytest.raises(ValueError, match="needs to have a shape of"):
EdgeIndex([[0, 1], [1, 0], [1, 1]])
with pytest.raises(ValueError, match="received a non-symmetric size"):
EdgeIndex([[0, 1], [1, 0]], is_undirected=True, sparse_size=(2, 3))
with pytest.raises(TypeError, match="invalid combination of arguments"):
EdgeIndex(torch.tensor([[0, 1], [1, 0]]), torch.long)
with pytest.raises(TypeError, match="invalid keyword arguments"):
EdgeIndex(torch.tensor([[0, 1], [1, 0]]), dtype=torch.long)
with pytest.raises(ValueError, match="contains negative indices"):
EdgeIndex([[-1, 0], [0, 1]]).validate()
with pytest.raises(ValueError, match="than its number of rows"):
EdgeIndex([[0, 10], [1, 0]], sparse_size=(2, 2)).validate()
with pytest.raises(ValueError, match="than its number of columns"):
EdgeIndex([[0, 1], [10, 0]], sparse_size=(2, 2)).validate()
with pytest.raises(ValueError, match="not sorted by row indices"):
EdgeIndex([[1, 0], [0, 1]], sort_order='row').validate()
with pytest.raises(ValueError, match="not sorted by column indices"):
EdgeIndex([[0, 1], [1, 0]], sort_order='col').validate()
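For contrast, a well-formed instance passes every one of these checks; a minimal sketch, reusing the sorted, in-bounds data from the other tests in this file:

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), sort_order='row')
adj.validate()  # no error: indices are non-negative, in bounds, and sorted by row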


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
def test_undirected(dtype, device):
@@ -78,7 +120,14 @@
assert adj.sparse_size() == (None, None)
adj.get_num_rows()
assert adj.sparse_size() == (3, 3)
adj.validate()

adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(3, None), **kwargs)
assert adj.sparse_size() == (3, 3)
adj.validate()

adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(None, 3), **kwargs)
assert adj.sparse_size() == (3, 3)
adj.validate()

with pytest.raises(ValueError, match="'EdgeIndex' is not undirected"):
@@ -164,6 +213,10 @@ def test_to(dtype, device, is_undirected):
assert adj._indptr.device == device
assert adj._T_perm.device == device

out = adj.cpu()
assert isinstance(out, EdgeIndex)
assert out.device == torch.device('cpu')

out = adj.to(torch.int)
assert out.dtype == torch.int
if torch_geometric.typing.WITH_PT20:
@@ -285,6 +338,10 @@ def test_cat(dtype, device, is_undirected):
args = dict(dtype=dtype, device=device, is_undirected=is_undirected)
adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **args)
adj2 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], sparse_size=(4, 4), **args)
adj3 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], dtype=dtype, device=device)

out = torch.cat([adj1], dim=1)
assert id(out) == id(adj1)

out = torch.cat([adj1, adj2], dim=1)
assert out.size() == (2, 8)
@@ -293,6 +350,13 @@ def test_cat(dtype, device, is_undirected):
assert not out.is_sorted
assert out.is_undirected == is_undirected

out = torch.cat([adj1, adj2, adj3], dim=1)
assert out.size() == (2, 12)
assert isinstance(out, EdgeIndex)
assert out.sparse_size() == (None, None)
assert not out.is_sorted
assert not out.is_undirected

out = torch.cat([adj1, adj2], dim=0)
assert out.size() == (4, 4)
assert not isinstance(out, EdgeIndex)
@@ -388,6 +452,9 @@ def test_getitem(dtype, device, is_undirected):
out = adj[tensor([0], device=device)]
assert not isinstance(out, EdgeIndex)

out = adj[tensor([0], device=device), tensor([0], device=device)]
assert not isinstance(out, EdgeIndex)


@withCUDA
@pytest.mark.parametrize('dtype', DTYPES)
@@ -426,6 +493,11 @@ def test_to_dense(dtype, device, value_dtype):
def test_to_sparse_coo(dtype, device):
kwargs = dict(dtype=dtype, device=device)
adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs)

if torch_geometric.typing.WITH_PT20:
with pytest.raises(ValueError, match="Unexpected tensor layout"):
adj.to_sparse(layout='int64')

if torch_geometric.typing.WITH_PT20:
out = adj.to_sparse(layout=torch.sparse_coo)
else:
@@ -436,6 +508,7 @@ def test_to_sparse_coo(dtype, device):
assert out.layout == torch.sparse_coo
assert out.size() == (3, 3)
assert adj.equal(out._indices())
assert not out.is_coalesced()

adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs)
out = adj.to_sparse_coo()
@@ -445,6 +518,17 @@ def test_to_sparse_coo(dtype, device):
assert out.layout == torch.sparse_coo
assert out.size() == (3, 3)
assert adj.equal(out._indices())
assert not out.is_coalesced()

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)
out = adj.to_sparse_coo()
assert isinstance(out, Tensor)
assert out.dtype == torch.float
assert out.device == device
assert out.layout == torch.sparse_coo
assert out.size() == (3, 3)
assert adj.equal(out._indices())
assert out.is_coalesced()


@withCUDA
@@ -632,6 +716,9 @@ def test_spmm(device, reduce, transpose, is_undirected):
# Basic:
x = torch.randn(3, 1, device=device)

with pytest.raises(ValueError, match="to be sorted by"):
adj.matmul(x, reduce=reduce, transpose=not transpose)

out = adj.matmul(x, reduce=reduce, transpose=transpose)
exp = _scatter_spmm(adj, x, None, reduce, transpose)
assert out.allclose(exp)
@@ -640,6 +727,9 @@ def test_spmm(device, reduce, transpose, is_undirected):
x = torch.randn(3, 1, device=device)
value = torch.rand(adj.size(1), device=device)

with pytest.raises(ValueError, match="'other_value' not supported"):
adj.matmul(x, reduce=reduce, other_value=value, transpose=transpose)

out = adj.matmul(x, value, reduce=reduce, transpose=transpose)
exp = _scatter_spmm(adj, x, value, reduce, transpose)
assert out.allclose(exp)
35 changes: 22 additions & 13 deletions torch_geometric/data/edge_index.py
@@ -26,7 +26,7 @@
torch.int32,
torch.int64,
}
else:
else: # pragma: no cover
SUPPORTED_DTYPES: Set[torch.dtype] = {
torch.int64,
}
@@ -83,8 +83,8 @@ def assert_contiguous(tensor: Tensor):

def assert_symmetric(size: Tuple[Optional[int], Optional[int]]):
if size[0] is not None and size[1] is not None and size[0] != size[1]:
raise ValueError("'EdgeIndex' is undirected but received a "
"non-symmetric size")
raise ValueError(f"'EdgeIndex' is undirected but received a "
f"non-symmetric size (got {list(size)})")


def assert_sorted(func):
@@ -101,7 +101,7 @@ def wrapper(*args, **kwargs):


class EdgeIndex(Tensor):
r"""An COO :obj:`edge_index` tensor with additional (meta)data attached.
r"""A COO :obj:`edge_index` tensor with additional (meta)data attached.
:class:`EdgeIndex` is a :pytorch:`null` :class:`torch.Tensor`, that holds an
:obj:`edge_index` representation of shape :obj:`[2, num_edges]`.
@@ -143,14 +143,14 @@ class EdgeIndex(Tensor):
>>> EdgeIndex([[0, 1, 1, 2],
... [1, 0, 2, 1]])
assert edge_index.is_sorted_by_row
assert not edge_index.is_undirected
assert edge_index.is_undirected
# Flipping order:
edge_index = edge_index.flip(0)
>>> EdgeIndex([[1, 0, 2, 1],
... [0, 1, 1, 2]])
assert edge_index.is_sorted_by_col
assert not edge_index.is_undirected
assert edge_index.is_undirected
# Filtering:
mask = torch.tensor([True, True, True, False])
@@ -161,7 +161,7 @@
assert not edge_index.is_undirected
# Sparse-Dense Matrix Multiplication:
out = edge_index @ torch.randn(3, 16)
out = edge_index.flip(0) @ torch.randn(3, 16)
assert out.size() == (3, 16)
"""
# See "https://pytorch.org/docs/stable/notes/extending.html"
@@ -293,6 +293,10 @@ def sparse_size(
r"""The size of the underlying sparse matrix.
If :obj:`dim` is specified, returns an integer holding the size of that
sparse dimension.
Args:
dim (int, optional): The dimension for which to retrieve the size.
(default: :obj:`None`)
"""
if dim is not None:
return self._sparse_size[dim]
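The dim argument thus mirrors tuple indexing, as exercised by the new assertions in test_basic above; a short usage sketch:

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))
assert adj.sparse_size() == (3, 3)  # the full (num_rows, num_cols) tuple
assert adj.sparse_size(0) == 3      # size of the first sparse dimension
assert adj.sparse_size(-1) == 3     # negative dims count from the end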
@@ -345,6 +349,10 @@ def get_sparse_size(
Automatically computed and cached when not explicitly set.
If :obj:`dim` is specified, returns an integer holding the size of that
sparse dimension.
Args:
dim (int, optional): The dimension for which to retrieve the size.
(default: :obj:`None`)
"""
if dim is not None:
if self._sparse_size[dim] is not None:
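Unlike sparse_size(), which only reports what is already known, get_sparse_size() fills in missing entries from the data and caches them. A sketch of the observable behavior, mirroring test_undirected above (the exact computation path is an assumption here):

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]])
assert adj.sparse_size() == (None, None)  # nothing cached yet
assert adj.get_sparse_size() == (3, 3)    # derived from the maximum index, then cached
assert adj.sparse_size() == (3, 3)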
@@ -466,7 +474,7 @@ def _get_value(self, dtype: Optional[torch.dtype] = None) -> Tensor:
if torch_geometric.typing.WITH_PT20 and not self.is_cuda:
value = torch.ones(1, dtype=dtype, device=self.device)
value = value.expand(self.size(1))
else:
else: # pragma: no cover
value = torch.ones(self.size(1), dtype=dtype, device=self.device)

self._value = value
@@ -844,7 +852,7 @@ def cpu(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex:


@implements(Tensor.cuda)
def cuda(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex:
def cuda(tensor: EdgeIndex, *args, **kwargs) -> EdgeIndex: # pragma: no cover
return apply_(tensor, Tensor.cuda, *args, **kwargs)


@@ -1098,9 +1106,9 @@ def forward(

if torch_geometric.typing.WITH_PT20 and not other.is_cuda:
return torch.sparse.mm(adj, other, reduce)

assert reduce == 'sum'
return adj @ other
else: # pragma: no cover
assert reduce == 'sum'
return adj @ other

@staticmethod
def backward(
@@ -1195,7 +1203,8 @@ def _spmm(
raise ValueError(f"'matmul(..., transpose=True)' requires "
f"'{cls_name}' to be sorted by colums")

if torch_geometric.typing.WITH_TORCH_SPARSE and other.is_cuda:
if (torch_geometric.typing.WITH_TORCH_SPARSE # pragma: no cover
and other.is_cuda):
return _torch_sparse_spmm(input, other, value, reduce, transpose)

if value is not None and value.requires_grad:

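Taken together, the matmul paths touched above support usage along these lines; a sketch, with sort_order='row' required for the non-transposed case, as the new ValueError tests check:

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), sort_order='row')
x = torch.randn(3, 16)
out = adj.matmul(x, reduce='sum')  # sparse-dense matmul, also spelled 'adj @ x'
assert out.size() == (3, 16)
# transpose=True would instead require an EdgeIndex sorted by column indices.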