Skip to content

Commit

Permalink
Enable distributed link hetero sampling (pyg-team#8722)
Browse files Browse the repository at this point in the history
This PR enables distributed edge sampling for heterogeneous graphs.

**Added:**
- Distributed edge heterogeneous sampling.
- Distributed edge heterogeneous node-level and edge-level temporal
sampling.
- `DistEdgeHeteroSamplerInput` class, which serves as an input data to
the `node_sample` function when for a given input edge there are
different source and target node types.
- unit tests

**Comments:**
- In the case when a given input edge has distinct source and
destination node types it is necessary to handle the data of each of
these types separately, so it is slightly different from the situation
when we have only one input node type.
- This PR depends on:
[pyg-team#8718](pyg-team#8718)

---------

Co-authored-by: JakubPietrakIntel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
4 people authored Feb 4, 2024
1 parent c65cff9 commit e4568a2
Show file tree
Hide file tree
Showing 5 changed files with 559 additions and 39 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519))
- Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399))
- Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369))
- Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), ([#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624))
- Update `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), ([#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624), [#8722](https://github.com/pyg-team/pytorch_geometric/pull/8722))
- Update `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))
- Disallow the usage of `add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))
- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))
Expand Down
13 changes: 6 additions & 7 deletions test/distributed/test_dist_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ def dist_link_neighbor_loader_hetero(

for batch in loader:
assert isinstance(batch, HeteroData)
assert (batch[edge_type].input_id.numel() ==
batch[edge_type].batch_size == 10)

assert len(batch.node_types) == 2
for node_type in batch.node_types:
assert torch.equal(batch[node_type].x, batch.x_dict[node_type])
assert batch.x_dict[node_type].size(0) >= 0
assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes

assert len(batch.edge_types) == 4
for edge_type in batch.edge_types:
assert (batch[edge_type].edge_attr.size(0) ==
batch[edge_type].edge_index.size(1))
for key in batch.edge_types:
if key[-1] == 'v0':
assert batch[key].num_sampled_edges[0] > 0
assert batch[key].edge_attr.size(0) == batch[key].num_edges
else:
batch[key].num_sampled_edges[0] == 0
assert loader.channel.empty()


Expand Down Expand Up @@ -208,7 +208,6 @@ def test_dist_link_neighbor_loader_homo(
@pytest.mark.parametrize('async_sampling', [True])
@pytest.mark.parametrize('neg_ratio', [None])
@pytest.mark.parametrize('edge_type', [('v0', 'e0', 'v0')])
@pytest.mark.skip(reason="'sample_from_edges' not yet implemented")
def test_dist_link_neighbor_loader_hetero(
tmp_path,
num_parts,
Expand Down
Loading

0 comments on commit e4568a2

Please sign in to comment.