Documentation of EdgeIndex.matmul (pyg-team#8526)
rusty1s authored Dec 4, 2023
1 parent 83ecb0f commit 44c9c12
Showing 2 changed files with 152 additions and 62 deletions.
131 changes: 87 additions & 44 deletions test/data/test_edge_index.py
@@ -506,8 +506,13 @@ def test_to_sparse_tensor(device):
@withPackage('torch_sparse')
@pytest.mark.parametrize('reduce', ReduceType.__args__)
@pytest.mark.parametrize('transpose', TRANSPOSE)
def test_torch_sparse_spmm(device, reduce, transpose):
adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_torch_sparse_spmm(device, reduce, transpose, is_undirected):
if is_undirected:
kwargs = dict(is_undirected=True)
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs)
else:
adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)
adj = adj.sort_by('col' if transpose else 'row').values

# Basic:
@@ -552,8 +557,13 @@ def test_torch_sparse_spmm(device, reduce, transpose):
@withCUDA
@pytest.mark.parametrize('reduce', ReduceType.__args__)
@pytest.mark.parametrize('transpose', TRANSPOSE)
def test_torch_spmm(device, reduce, transpose):
adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_torch_spmm(device, reduce, transpose, is_undirected):
if is_undirected:
kwargs = dict(is_undirected=True)
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs)
else:
adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)
adj, perm = adj.sort_by('col' if transpose else 'row')

# Basic:
@@ -607,29 +617,62 @@ def test_torch_spmm(device, reduce, transpose):
out.backward(grad)


def test_matmul_forward():
x = torch.randn(3, 1)
adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
adj1_dense = adj1.to_dense()
adj2 = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col')
adj2_dense = adj2.to_dense()
@withCUDA
@pytest.mark.parametrize('reduce', ReduceType.__args__)
@pytest.mark.parametrize('transpose', TRANSPOSE)
@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)
def test_spmm(device, reduce, transpose, is_undirected):
if is_undirected:
kwargs = dict(is_undirected=True)
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs)
else:
adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)
adj = adj.sort_by('col' if transpose else 'row').values

# Basic:
x = torch.randn(3, 1, device=device)

out = adj1 @ x
assert torch.allclose(out, adj1_dense @ x)
out = adj.matmul(x, reduce=reduce, transpose=transpose)
exp = _scatter_spmm(adj, x, None, reduce, transpose)
assert out.allclose(exp)

out = adj1.matmul(x)
assert torch.allclose(out, adj1_dense @ x)
# With non-zero values:
x = torch.randn(3, 1, device=device)
value = torch.rand(adj.size(1), device=device)

out = torch.matmul(adj1, x)
assert torch.allclose(out, adj1_dense @ x)
out = adj.matmul(x, value, reduce=reduce, transpose=transpose)
exp = _scatter_spmm(adj, x, value, reduce, transpose)
assert out.allclose(exp)

if torch_geometric.typing.WITH_PT20:
out = torch.sparse.mm(adj1, x, reduce='sum')
else:
with pytest.raises(TypeError, match="got an unexpected keyword"):
torch.sparse.mm(adj1, x, reduce='sum')
out = torch.sparse.mm(adj1, x)
assert torch.allclose(out, adj1_dense @ x)
# Gradients w.r.t. other:
x1 = torch.randn(3, 1, device=device, requires_grad=True)
x2 = x1.detach().requires_grad_()
grad = torch.randn_like(x1)

out = adj.matmul(x1, reduce=reduce, transpose=transpose)
out.backward(grad)
exp = _scatter_spmm(adj, x2, None, reduce, transpose)
exp.backward(grad)
assert x1.grad.allclose(x2.grad)

# Gradients w.r.t. value:
x = torch.randn(3, 1, device=device)
value1 = torch.rand(adj.size(1), device=device, requires_grad=True)
value2 = value1.detach().requires_grad_()
grad = torch.randn_like(x)

out = adj.matmul(x, value1, reduce=reduce, transpose=transpose)
out.backward(grad)
exp = _scatter_spmm(adj, x, value2, reduce, transpose)
exp.backward(grad)
assert value1.grad.allclose(value2.grad)


def test_spspmm():
adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
adj1_dense = adj1.to_dense()
adj2 = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col')
adj2_dense = adj2.to_dense()

out, value = adj1 @ adj1
assert isinstance(out, EdgeIndex)
@@ -651,32 +694,32 @@ def test_matmul_forward():
assert torch.allclose(out.to_dense(value), adj2_dense @ adj2_dense)


def test_matmul_input_value():
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

x = torch.randn(3, 1)
value = torch.randn(4)

out = adj.matmul(x, input_value=value)
assert torch.allclose(out, adj.to_dense(value) @ x)

@withCUDA
def test_matmul(device):
kwargs = dict(sort_order='row', device=device)
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)
x = torch.randn(3, 1, device=device)
expected = adj.to_dense() @ x

def test_matmul_backward():
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
out = adj @ x
assert torch.allclose(out, expected)

x1 = torch.randn(3, 1, requires_grad=True)
value = torch.randn(4)
out = adj.matmul(x)
assert torch.allclose(out, expected)

out = adj.matmul(x1, input_value=value)
grad_out = torch.randn_like(out)
out.backward(grad_out)
out = torch.mm(adj, x)
assert torch.allclose(out, expected)

x2 = x1.detach().requires_grad_()
dense_adj = adj.to_dense(value)
out = dense_adj @ x2
out.backward(grad_out)
out = torch.matmul(adj, x)
assert torch.allclose(out, expected)

assert torch.allclose(x1.grad, x2.grad)
if torch_geometric.typing.WITH_PT20:
out = torch.sparse.mm(adj, x, reduce='sum')
else:
with pytest.raises(TypeError, match="got an unexpected keyword"):
torch.sparse.mm(adj, x, reduce='sum')
out = torch.sparse.mm(adj, x)
assert torch.allclose(out, expected)


@withCUDA
83 changes: 65 additions & 18 deletions torch_geometric/data/edge_index.py
@@ -560,9 +560,9 @@ def to_dense(
r"""Converts :class:`EdgeIndex` into a dense :class:`torch.Tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
value (torch.Tensor, optional): The values for non-zero elements.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`. (default: :obj:`None`)
fill_value (float, optional): The fill value for remaining elements
in the dense matrix. (default: :obj:`0.0`)
dtype (torch.dtype, optional): The data type of the returned
@@ -586,9 +586,9 @@ def to_sparse_coo(self, value: Optional[Tensor] = None) -> Tensor:
:class:`torch.sparse_coo_tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
value (torch.Tensor, optional): The values for non-zero elements.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`. (default: :obj:`None`)
"""
value = self._get_value() if value is None else value
out = torch.sparse_coo_tensor(
@@ -609,9 +609,9 @@ def to_sparse_csr(self, value: Optional[Tensor] = None) -> Tensor:
:class:`torch.sparse_csr_tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
value (torch.Tensor, optional): The values for non-zero elements.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`. (default: :obj:`None`)
"""
(rowptr, col), perm = self.get_csr()
value = self._get_value() if value is None else value[perm]
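
(Editor's sketch, not part of the diff: the line above shows that a user-supplied `value` follows the original edge order and is permuted into CSR order via the `perm` vector from `get_csr()`. A minimal illustration, assuming `EdgeIndex` is importable from the module this diff touches:

    import torch
    from torch_geometric.data.edge_index import EdgeIndex  # assumed import path

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

    # One entry per edge, given in the original edge order; `to_sparse_csr`
    # reorders it internally via `perm` before building the CSR tensor:
    value = torch.arange(4, dtype=torch.float)
    csr = adj.to_sparse_csr(value)
    assert csr.layout == torch.sparse_csr
)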
@@ -630,9 +630,9 @@ def to_sparse_csc(self, value: Optional[Tensor] = None) -> Tensor:
:class:`torch.sparse_csc_tensor`.
Args:
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
value (torch.Tensor, optional): The values for non-zero elements.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`. (default: :obj:`None`)
"""
if not torch_geometric.typing.WITH_PT112:
raise NotImplementedError(
@@ -663,9 +663,9 @@ def to_sparse(
layout (torch.layout, optional): The desired sparse layout. One of
:obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`, or
:obj:`torch.sparse_csc`. (default: :obj:`torch.sparse_coo`)
value (torch.Tensor, optional): The values for sparse indices. If
not specified, sparse indices will be assigned a value of
:obj:`1`. (default: :obj:`None`)
value (torch.Tensor, optional): The values for non-zero elements.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`. (default: :obj:`None`)
"""
if layout is None or layout == torch.sparse_coo:
return self.to_sparse_coo(value)
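
(Editor's sketch, not part of the diff: a quick illustration of the documented `value` default and of layout selection in `to_sparse`; the import path and the 3-node example graph are assumptions:

    import torch
    from torch_geometric.data.edge_index import EdgeIndex  # assumed import path

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')

    # Without `value`, every non-zero element defaults to 1.0:
    dense = adj.to_dense()  # 3x3 tensor with 1.0 at the four edge positions

    # `to_sparse` dispatches on `layout`:
    coo = adj.to_sparse()                         # torch.sparse_coo (default)
    csr = adj.to_sparse(layout=torch.sparse_csr)  # torch.sparse_csr
    assert torch.allclose(dense, coo.to_dense())
)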
@@ -685,7 +685,7 @@ def to_sparse_tensor(
Requires that :obj:`torch-sparse` is installed.
Args:
value (torch.Tensor, optional): The values for sparse indices.
value (torch.Tensor, optional): The values for non-zero elements.
(default: :obj:`None`)
"""
return SparseTensor(
@@ -706,7 +706,54 @@ def matmul(
reduce: ReduceType = 'sum',
transpose: bool = False,
) -> Union[Tensor, Tuple['EdgeIndex', Tensor]]:
# TODO Add doc-string
r"""Performs a matrix multiplication of the matrices :obj:`input` and
:obj:`other`.
If :obj:`input` is a :math:`(n \times m)` matrix and :obj:`other` is a
:math:`(m \times p)` tensor, then the output will be a
:math:`(n \times p)` tensor.
See :meth:`torch.matmul` for more information.
:obj:`input` is a sparse matrix as denoted by the indices in
:class:`EdgeIndex`, and :obj:`input_value` corresponds to the values
of non-zero elements in :obj:`input`.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`.
:obj:`other` can either be a dense :class:`torch.Tensor` or a sparse
:class:`EdgeIndex`.
If :obj:`other` is a sparse :class:`EdgeIndex`, then :obj:`other_value`
corresponds to the values of its non-zero elements.
This function additionally accepts an optional :obj:`reduce` argument
that specifies the reduction operation to apply.
See :meth:`torch.sparse.mm` for more information.
Lastly, the :obj:`transpose` option allows performing matrix
multiplication where :obj:`input` is transposed first, *i.e.*:
.. math::
\textrm{input}^{\top} \cdot \textrm{other}
Args:
other (torch.Tensor or EdgeIndex): The second matrix to be
multiplied, which can be sparse or dense.
input_value (torch.Tensor, optional): The values for non-zero
elements of :obj:`input`.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`. (default: :obj:`None`)
other_value (torch.Tensor, optional): The values for non-zero
elements of :obj:`other` in case it is sparse.
If not specified, non-zero elements will be assigned a value of
:obj:`1.0`. (default: :obj:`None`)
reduce (str, optional): The reduce operation, one of
:obj:`"sum"`/:obj:`"add"`, :obj:`"mean"`,
:obj:`"min"`/:obj:`"amin"` or :obj:`"max"`/:obj:`"amax"`.
(default: :obj:`"sum"`)
transpose (bool, optional): If set to :obj:`True`, will perform
matrix multiplication based on the transposed :obj:`input`.
(default: :obj:`False`)
"""
return matmul(self, other, input_value, other_value, reduce, transpose)
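
(Editor's sketch, not part of the diff: a usage example matching the new doc-string. The import path and shapes are assumptions, and non-default `reduce` values may require a sufficiently recent PyTorch:

    import torch
    from torch_geometric.data.edge_index import EdgeIndex  # assumed import path

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    x = torch.randn(3, 4)

    # Sparse-dense multiplication returns a dense (3, 4) tensor:
    out = adj.matmul(x)
    assert torch.allclose(out, adj.to_dense() @ x)

    # Non-zero values and a non-default reduction:
    value = torch.rand(4)
    out = adj.matmul(x, input_value=value, reduce='max')

    # Sparse-sparse multiplication returns an `(EdgeIndex, value)` tuple:
    out, out_value = adj.matmul(adj)
)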

@classmethod
@@ -1182,7 +1229,6 @@ def matmul(
reduce: ReduceType = 'sum',
transpose: bool = False,
) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:

if reduce not in ReduceType.__args__:
raise NotImplementedError(f"`reduce='{reduce}'` not yet supported")

@@ -1221,6 +1267,7 @@ def matmul(
return edge_index, out.values()
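
(Editor's sketch, not part of the diff: the `@implements(...)` registrations below appear to hook these overrides into PyTorch's `__torch_function__` dispatch, so the public entry points reach the sparse implementation. Assuming the import path, the newly registered `torch.mm` route should agree with the existing ones:

    import torch
    from torch_geometric.data.edge_index import EdgeIndex  # assumed import path

    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row')
    x = torch.randn(3, 2)

    out1 = torch.mm(adj, x)      # newly registered in this commit
    out2 = torch.matmul(adj, x)
    out3 = adj @ x
    assert torch.allclose(out1, out2) and torch.allclose(out2, out3)
)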


@implements(torch.mm)
@implements(torch.matmul)
@implements(Tensor.matmul)
def _matmul1(
