Skip to content

Commit

Permalink
Use optimized implementation in softmax operation (pyg-team#8399)
Browse files Browse the repository at this point in the history
This PR uses the optimized `softmax_csr` operation (introduced in
[pyg-lib#264](pyg-team/pyg-lib#264)) when the given input is a CPU
tensor and the softmax groups are defined via `ptr`.
  • Loading branch information
DamianSzwichtenberg authored Nov 21, 2023
1 parent c7483ac commit af586eb
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399))
- Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369))
- Update `DistNeighborSampler` for homogeneous graphs ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367))
- Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
Expand Down
15 changes: 9 additions & 6 deletions test/utils/test_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from torch_geometric.profile import benchmark
from torch_geometric.utils import softmax

CALCULATION_VIA_PTR_AVAILABLE = (torch_geometric.typing.WITH_SOFTMAX
or torch_geometric.typing.WITH_TORCH_SCATTER)


def test_softmax():
src = torch.tensor([1., 1., 1., 1.])
Expand All @@ -13,7 +16,7 @@ def test_softmax():

out = softmax(src, index)
assert out.tolist() == [0.5, 0.5, 1, 1]
if torch_geometric.typing.WITH_TORCH_SCATTER:
if CALCULATION_VIA_PTR_AVAILABLE:
assert softmax(src, None, ptr).tolist() == out.tolist()
else:
with pytest.raises(ImportError):
Expand All @@ -22,7 +25,7 @@ def test_softmax():
src = src.view(-1, 1)
out = softmax(src, index)
assert out.tolist() == [[0.5], [0.5], [1], [1]]
if torch_geometric.typing.WITH_TORCH_SCATTER:
if CALCULATION_VIA_PTR_AVAILABLE:
assert softmax(src, None, ptr).tolist() == out.tolist()

jit = torch.jit.script(softmax)
Expand Down Expand Up @@ -52,22 +55,22 @@ def test_softmax_dim():

src = torch.randn(4)
assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0))
if torch_geometric.typing.WITH_TORCH_SCATTER:
if CALCULATION_VIA_PTR_AVAILABLE:
assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))

src = torch.randn(4, 16)
assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0))
if torch_geometric.typing.WITH_TORCH_SCATTER:
if CALCULATION_VIA_PTR_AVAILABLE:
assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))

src = torch.randn(4, 4)
assert torch.allclose(softmax(src, index, dim=-1), src.softmax(dim=-1))
if torch_geometric.typing.WITH_TORCH_SCATTER:
if CALCULATION_VIA_PTR_AVAILABLE:
assert torch.allclose(softmax(src, ptr=ptr, dim=-1), src.softmax(-1))

src = torch.randn(4, 4, 16)
assert torch.allclose(softmax(src, index, dim=1), src.softmax(dim=1))
if torch_geometric.typing.WITH_TORCH_SCATTER:
if CALCULATION_VIA_PTR_AVAILABLE:
assert torch.allclose(softmax(src, ptr=ptr, dim=1), src.softmax(dim=1))


Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
WITH_GMM = False
WITH_SEGMM = False
WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add')
WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr')
WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')
WITH_METIS = hasattr(pyg_lib, 'partition')
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature(
Expand All @@ -55,6 +56,7 @@
WITH_GMM = False
WITH_SEGMM = False
WITH_SAMPLED_OP = False
WITH_SOFTMAX = False
WITH_INDEX_SORT = False
WITH_METIS = False
WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
Expand Down
5 changes: 5 additions & 0 deletions torch_geometric/utils/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from torch import Tensor

import torch_geometric.typing
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import scatter, segment
from torch_geometric.utils.num_nodes import maybe_num_nodes

Expand Down Expand Up @@ -50,6 +52,9 @@ def softmax(
[0.8062, 0.1938, 1.0000, 1.0000]])
"""
if ptr is not None:
if (src.device.type == 'cpu'
and torch_geometric.typing.WITH_SOFTMAX): # pragma: no cover
return pyg_lib.ops.softmax_csr(src, ptr, dim)
dim = dim + src.dim() if dim < 0 else dim
size = ([1] * dim) + [-1]
count = ptr[1:] - ptr[:-1]
Expand Down

0 comments on commit af586eb

Please sign in to comment.