diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ac322c070bb..96d18f7674a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added end-to-end example for distributed CPU training ([#8713](https://github.com/pyg-team/pytorch_geometric/pull/8713)) - Added PyTorch 2.2 support ([#8857](https://github.com/pyg-team/pyg-lib/pull/8857)) - Added fallback code path for `segment` in case `torch-scatter` is not installed ([#8852](https://github.com/pyg-team/pytorch_geometric/pull/8852)) - Added support for custom node labels in `visualize_graph()` ([#8816](https://github.com/pyg-team/pytorch_geometric/pull/8816)) diff --git a/examples/distributed/pyg/README.md b/examples/distributed/pyg/README.md new file mode 100644 index 000000000000..869a2ca143fe --- /dev/null +++ b/examples/distributed/pyg/README.md @@ -0,0 +1,101 @@ +# Distributed Training with PyG + +**[`torch_geometric.distributed`](https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric/distributed)** implements a scalable solution for distributed GNN training, built exclusively upon PyTorch and PyG. + +Current application can be deployed on a cluster of arbitrary size using multiple CPUs. +PyG native GPU application is under development and will be released soon. + +The solution is designed to effortlessly distribute the training of large-scale graph neural networks across multiple nodes, thanks to the integration of Distributed Data Parallelism (DDP) for model training and Remote Procedure Call (RPC) for efficient sampling and fetching of non-local features. +The design includes a number of custom classes, *i.e.* (1) `DistNeighborSampler` implements CPU sampling algorithms and feature extraction from local and remote data remaining consistent data structure at the output, (2) an integrated `DistLoader` which ensures safe opening & closing of RPC connection between the samplers, and (3) a METIS-based `Partitioner` and many more. + +## Example for distributed training on OGB datasets + +The example provided in `distributed_cpu.py` performs distributed training using [OGB](https://ogb.stanford.edu/) datasets and a `GraphSAGE` model. +The example can run on both homogeneous (`ogbn-products`) and heterogeneous data (`ogbn-mag`). +With minor modifications, the example can be extended to train on `ogbn-papers100m` or any other dataset. + +To run the example please refer to the steps below. + +### Requirements + +- [`torch-geometric>=2.5.0`](https://github.com/pyg-team/pytorch_geometric) and [`pyg-lib>=0.4.0`](https://github.com/pyg-team/pyg-lib) +- Password-less SSH needs to be set up on all the nodes that you are using (see the [Linux SSH manual](https://linuxize.com/post/how-to-setup-passwordless-ssh-login)). +- All nodes need to have a consistent environments installed, specifically `torch` and `pyg-lib` versions must be the same. + You might want to consider using docker containers. +- *\[Optional\]* In some cases Linux firewall might be blocking TCP connection issues. + Ensure that firewall settings allow for all nodes to communicate (see the [Linux firewall manual](https://ubuntu.com/server/docs/security-firewall)). + For this example TCP ports `11111`, `11112` and `11113` should be open (*i.e.* `sudo ufw allow 11111`). + +This example shows how to use distributed PyG to train GNN models in a distributed scenario with multiple CPU nodes. + +### Step 1: Prepare and Partition the Data + +In distributed training, each node in the cluster holds a partition of the graph. +Before the training starts, we partition the dataset into multiple partitions, each of which corresponds to a specific training node. + +Here, we use `ogbn-products` and partition it into two partitions (in default) via the [`partition_graph.py` script](./partition_graph.py) : + +```bash +python partition_graph.py --dataset=ogbn-products --root_dir=../../../data --num_partitions=2 +``` + +**Caution:** Partitioning with METIS is non-deterministic! +All nodes should be able to access the same partition data. +Therefore, generate the partitions on one node and copy the data to all members of the cluster, or place the folder into a shared location. + +The generated partition will have a folder structure as below: + +partition-graph + +### Step 2: Run the Example in Each Training Node + +To run the example, you can execute the commands in each node or use the provided launch script. + +#### Option A: Manual Execution + +You should change the `master_addr` to the IP of `node#0`. +Make sure that the correct `node_rank` is provided, with the master node assigned to rank `0`. +The `dataset_root_dir` should point to the head directory where your partition is placed, *i.e.* `../../data/partitions/ogbn-products/2-parts`: + +```bash +# Node 0: +python distributed_cpu.py \ + --dataset=ogbn-products + --dataset_root_dir= \ + --num_nodes=2 \ + --node_rank=0 \ + --master_addr= + +# Node 1: +python distributed_cpu.py \ + --dataset=ogbn-products + --dataset_root_dir= \ + --num_nodes=2 \ + --node_rank=1 \ + --master_addr= +``` + +## Option B: Launch script + +See [#8241](https://github.com/pyg-team/pytorch_geometric/pull/8241). + +There exists two methods to run the distributed example with one script in one terminal for multiple nodes: + +1. `launch.py`: + ```bash + python launch.py + --workspace {workspace}/pytorch_geometric + --num_nodes 2 + --dataset_root_dir {dataset_dir}/mag/2-parts + --dataset ogbn-mag + --batch_size 1024 + --learning_rate 0.0004 + --part_config {dataset_dir}/mag/2-parts/ogbn-mag-partitions/META.json + --ip_config {workspace}/pytorch_geometric/ip_config.yaml + 'cd /home/user_xxx; source {conda_envs}/bin/activate; cd {workspace}/pytorch_geometric; {conda_envs}/bin/python + {workspace}/pytorch_geometric/examples/pyg/distributed_cpu.py --dataset=ogbn-mag --logging --progress_bar --ddp_port=11111' + ``` +1. `run_dist.sh`: All parameter settings are contained in the `run_dist.sh` script and you just need run with: + ```bash + ./run_dist.sh + ``` diff --git a/examples/distributed/pyg/distributed_cpu.py b/examples/distributed/pyg/distributed_cpu.py new file mode 100644 index 000000000000..a08af906aba5 --- /dev/null +++ b/examples/distributed/pyg/distributed_cpu.py @@ -0,0 +1,471 @@ +import argparse +import os.path as osp +import time +from contextlib import nullcontext + +import torch +import torch.distributed +import torch.nn.functional as F +from torch.nn.parallel import DistributedDataParallel +from tqdm import tqdm + +import torch_geometric.distributed as pyg_dist +from torch_geometric.data import HeteroData +from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore +from torch_geometric.distributed.dist_context import DistContext +from torch_geometric.nn import GraphSAGE, to_hetero + + +@torch.no_grad() +def test( + model, + loader, + dist_context, + device, + epoch, + logfile=None, + num_loader_threads=10, + progress_bar=True, +): + def test_homo(batch): + out = model(batch.x, batch.edge_index)[:batch.batch_size] + y_pred = out.argmax(dim=-1) + y_true = batch.y[:batch.batch_size] + return y_pred, y_true + + def test_hetero(batch): + batch_size = batch['paper'].batch_size + out = model(batch.x_dict, batch.edge_index_dict) + out = out['paper'][:batch_size] + y_pred = out.argmax(dim=-1) + y_true = batch['paper'].y[:batch_size] + return y_pred, y_true + + total_examples = total_correct = 0 + + if loader.num_workers > 0: + context = loader.enable_multithreading(num_loader_threads) + else: + context = nullcontext() + + with context: + if progress_bar: + loader = tqdm(loader, desc=f'[Node {dist_context.rank}] Test') + + start_time = batch_time = time.time() + for i, batch in enumerate(loader): + batch = batch.to(device) + if isinstance(batch, HeteroData): + y_pred, y_true = test_hetero(batch) + else: + y_pred, y_true = test_homo(batch) + + total_correct += int((y_pred == y_true).sum()) + total_examples += y_pred.size(0) + batch_acc = int((y_pred == y_true).sum()) / y_pred.size(0) + + result = (f'[Node {dist_context.rank}] Test: epoch={epoch}, ' + f'it={i}, acc={batch_acc:.4f}, ' + f'time={(time.time() - batch_time):.4f}') + batch_time = time.time() + + if logfile: + log = open(logfile, 'a+') + log.write(f'{result}\n') + log.close() + + if not progress_bar: + print(result) + + total_acc = total_correct / total_examples + print(f'[Node {dist_context.rank}] Test epoch {epoch} END: ' + f'acc={total_acc:.4f}, time={(time.time() - start_time):.2f}') + torch.distributed.barrier() + + +def training( + model, + loader, + optimizer, + dist_context, + device, + epoch, + logfile=None, + num_loader_threads=10, + progress_bar=True, +): + def train_homo(batch): + batch = batch.to(device) + optimizer.zero_grad() + out = model(batch.x, batch.edge_index)[:batch.batch_size] + loss = F.cross_entropy(out, batch.y[:batch.batch_size]) + return loss, batch.batch_size + + def train_hetero(batch): + batch_size = batch['paper'].batch_size + batch = batch.to(device, 'edge_index') + optimizer.zero_grad() + out = model(batch.x_dict, batch.edge_index_dict) + out = out['paper'][:batch_size] + target = batch['paper'].y[:batch_size] + loss = F.cross_entropy(out, target) + return loss, batch_size + + total_examples = total_loss = 0 + + if loader.num_workers > 0: + context = loader.enable_multithreading(num_loader_threads) + else: + context = nullcontext() + + with context: + if progress_bar: + loader = tqdm(loader, desc=f'[Node {dist_context.rank}] Train') + + start_time = batch_time = time.time() + for i, batch in enumerate(loader): + batch = batch.to(device) + if isinstance(batch, HeteroData): + loss, batch_size = train_hetero(batch) + else: + loss, batch_size = train_homo(batch) + + loss.backward() + optimizer.step() + + total_loss += float(loss) * batch_size + total_examples += batch_size + + result = (f'[Node {dist_context.rank}] Train: epoch={epoch}, ' + f'it={i}, loss={loss:.4f}, ' + f'time={(time.time() - batch_time):.4f}') + batch_time = time.time() + + if logfile: + log = open(logfile, 'a+') + log.write(f'{result}\n') + log.close() + + if not progress_bar: + print(result) + + print(f'[Node {dist_context.rank}] Train epoch {epoch} END: ' + f'loss={total_loss/total_examples:.4f}, ' + f'time={(time.time() - start_time):.2f}') + torch.distributed.barrier() + + +def run_proc( + local_proc_rank: int, + num_nodes: int, + node_rank: int, + dataset: str, + dataset_root_dir: str, + master_addr: str, + ddp_port: int, + train_loader_port: int, + test_loader_port: int, + num_epochs: int, + batch_size: int, + num_neighbors: str, + async_sampling: bool, + concurrency: int, + num_workers: int, + num_loader_threads: int, + progress_bar: bool, + logfile: str, +): + is_hetero = dataset == 'ogbn-mag' + + print('--- Loading data partition files ...') + root_dir = osp.join(osp.dirname(osp.realpath(__file__)), dataset_root_dir) + node_label_file = osp.join(root_dir, f'{dataset}-label', 'label.pt') + train_idx = torch.load( + osp.join( + root_dir, + f'{dataset}-train-partitions', + f'partition{node_rank}.pt', + )) + test_idx = torch.load( + osp.join( + root_dir, + f'{dataset}-test-partitions', + f'partition{node_rank}.pt', + )) + + if is_hetero: + train_idx = ('paper', train_idx) + test_idx = ('paper', test_idx) + + # Load partition into local graph store: + graph = LocalGraphStore.from_partition( + osp.join(root_dir, f'{dataset}-partitions'), node_rank) + # Load partition into local feature store: + feature = LocalFeatureStore.from_partition( + osp.join(root_dir, f'{dataset}-partitions'), node_rank) + feature.labels = torch.load(node_label_file) + partition_data = (feature, graph) + print(f'Partition metadata: {graph.meta}') + + # Initialize distributed context: + current_ctx = DistContext( + world_size=num_nodes, + rank=node_rank, + global_world_size=num_nodes, + global_rank=node_rank, + group_name='distributed-sage-supervised-Node', + ) + current_device = torch.device('cpu') + + print('--- Initialize DDP training group ...') + # Initialize DDP training process group. + torch.distributed.init_process_group( + backend='gloo', + rank=current_ctx.rank, + world_size=current_ctx.world_size, + init_method='tcp://{}:{}'.format(master_addr, ddp_port), + ) + print('--- Initialize distributed loaders ...') + + num_neighbors = [int(i) for i in num_neighbors.split(',')] + # Create distributed neighbor loader for training: + train_loader = pyg_dist.DistNeighborLoader( + data=partition_data, + input_nodes=train_idx, + current_ctx=current_ctx, + device=current_device, + num_neighbors=num_neighbors, + shuffle=True, + drop_last=True, + persistent_workers=num_workers > 0, + batch_size=batch_size, + num_workers=num_workers, + master_addr=master_addr, + master_port=train_loader_port, + concurrency=concurrency, + async_sampling=async_sampling, + ) + # Create distributed neighbor loader for testing: + test_loader = pyg_dist.DistNeighborLoader( + data=partition_data, + input_nodes=test_idx, + current_ctx=current_ctx, + device=current_device, + num_neighbors=num_neighbors, + shuffle=False, + drop_last=False, + persistent_workers=num_workers > 0, + batch_size=batch_size, + num_workers=num_workers, + master_addr=master_addr, + master_port=test_loader_port, + concurrency=concurrency, + async_sampling=async_sampling, + ) + + print('--- Initialize model ...') + # Define model and optimizer. + model = GraphSAGE( + in_channels=128 if is_hetero else 100, # num_features + hidden_channels=256, + num_layers=len(num_neighbors), + out_channels=349 if is_hetero else 47, # num_classes in dataset + ).to(current_device) + + if is_hetero: # Turn model into a heterogeneous variant: + metadata = [ + graph.meta['node_types'], + [tuple(e) for e in graph.meta['edge_types']], + ] + model = to_hetero(model, metadata).to(current_device) + torch.distributed.barrier() + + # Enable DDP: + model = DistributedDataParallel(model, find_unused_parameters=is_hetero) + optimizer = torch.optim.Adam(model.parameters(), lr=0.0004) + torch.distributed.barrier() + + # Train and test: + print(f'--- Start training for {num_epochs} epochs ...') + for epoch in range(1, num_epochs + 1): + print(f'Train epoch {epoch}/{num_epochs}:') + model.train() + training( + model, + train_loader, + optimizer, + current_ctx, + current_device, + epoch, + logfile, + num_loader_threads, + progress_bar, + ) + + if epoch % 5 == 0: + print(f'Test epoch {epoch}/{num_epochs}:') + model.eval() + test( + model, + test_loader, + current_ctx, + current_device, + epoch, + logfile, + num_loader_threads, + progress_bar, + ) + print(f'--- [Node {current_ctx.rank}] Closing ---') + torch.distributed.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Arguments for distributed training') + + parser.add_argument( + '--dataset', + type=str, + default='ogbn-products', + choices=['ogbn-products', 'ogbn-mag'], + help='Name of ogbn dataset: (ogbn-products, ogbn-mag)', + ) + parser.add_argument( + '--dataset_root_dir', + type=str, + default='../../../data/partitions/ogbn-products/2-parts', + help='The root directory (relative path) of partitioned dataset', + ) + parser.add_argument( + '--num_nodes', + type=int, + default=2, + help='Number of distributed nodes', + ) + parser.add_argument( + '--num_neighbors', + type=str, + default='15,10,5', + help='Number of node neighbors sampled at each layer', + ) + parser.add_argument( + '--node_rank', + type=int, + default=0, + help='The current node rank', + ) + parser.add_argument( + '--num_epochs', + type=int, + default=100, + help='The number of training epochs', + ) + parser.add_argument( + '--batch_size', + type=int, + default=1024, + help='Batch size for training and testing', + ) + parser.add_argument( + '--num_workers', + type=int, + default=4, + help='Number of sampler sub-processes', + ) + parser.add_argument( + '--num_loader_threads', + type=int, + default=10, + help='Number of threads used for each sampler sub-process', + ) + parser.add_argument( + '--concurrency', + type=int, + default=4, + help='Number of maximum concurrent RPC for each sampler', + ) + parser.add_argument( + '--async_sampling', + type=bool, + default=True, + help='Whether sampler processes RPC requests asynchronously', + ) + parser.add_argument( + '--master_addr', + type=str, + default='localhost', + help='The master address for RPC initialization', + ) + parser.add_argument( + '--ddp_port', + type=int, + default=11111, + help="The port used for PyTorch's DDP communication", + ) + parser.add_argument( + '--train_loader_port', + type=int, + default=11112, + help='The port used for RPC communication across training samplers', + ) + parser.add_argument( + '--test_loader_port', + type=int, + default=11113, + help='The port used for RPC communication across test samplers', + ) + parser.add_argument('--logging', action='store_true') + parser.add_argument('--progress_bar', action='store_true') + + args = parser.parse_args() + + print('--- Distributed training example with GraphSAGE ---') + print(f'* total nodes: {args.num_nodes}') + print(f'* node rank: {args.node_rank}') + print(f'* dataset: {args.dataset}') + print(f'* dataset root dir: {args.dataset_root_dir}') + print(f'* epochs: {args.num_epochs}') + print(f'* batch size: {args.batch_size}') + print(f'* number of sampler workers: {args.num_workers}') + print(f'* master addr: {args.master_addr}') + print(f'* training process group master port: {args.ddp_port}') + print(f'* training loader master port: {args.train_loader_port}') + print(f'* testing loader master port: {args.test_loader_port}') + print(f'* RPC asynchronous processing: {args.async_sampling}') + print(f'* RPC concurrency: {args.concurrency}') + print(f'* loader multithreading: {args.num_loader_threads}') + print(f'* logging enabled: {args.logging}') + print(f'* progress bars enabled: {args.progress_bar}') + + if args.logging: + logfile = f'dist_cpu-node{args.node_rank}.txt' + with open(logfile, 'a+') as log: + log.write(f'\n--- Inputs: {str(args)}') + else: + logfile = None + + print('--- Launching training processes ...') + torch.multiprocessing.spawn( + run_proc, + args=( + args.num_nodes, + args.node_rank, + args.dataset, + args.dataset_root_dir, + args.master_addr, + args.ddp_port, + args.train_loader_port, + args.test_loader_port, + args.num_epochs, + args.batch_size, + args.num_neighbors, + args.async_sampling, + args.concurrency, + args.num_workers, + args.num_loader_threads, + args.progress_bar, + logfile, + ), + join=True, + ) + print('--- Finished training processes ...') diff --git a/test/distributed/test_dist_link_neighbor_loader.py b/test/distributed/test_dist_link_neighbor_loader.py index 0599606b151b..c8d75909c7b7 100644 --- a/test/distributed/test_dist_link_neighbor_loader.py +++ b/test/distributed/test_dist_link_neighbor_loader.py @@ -15,32 +15,12 @@ LocalGraphStore, Partitioner, ) -from torch_geometric.distributed.partition import load_partition_info from torch_geometric.testing import onlyDistributedTest def create_dist_data(tmp_path: str, rank: int): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) - ( - meta, - num_partitions, - partition_idx, - node_pb, - edge_pb, - ) = load_partition_info(tmp_path, rank) - - graph_store.partition_idx = partition_idx - graph_store.num_partitions = num_partitions - graph_store.node_pb = node_pb - graph_store.edge_pb = edge_pb - graph_store.meta = meta - - feat_store.partition_idx = partition_idx - feat_store.num_partitions = num_partitions - feat_store.node_feat_pb = node_pb - feat_store.edge_feat_pb = edge_pb - feat_store.meta = meta return feat_store, graph_store diff --git a/test/distributed/test_dist_link_neighbor_sampler.py b/test/distributed/test_dist_link_neighbor_sampler.py index 7b9b76128dfe..eb1f28f2ef13 100644 --- a/test/distributed/test_dist_link_neighbor_sampler.py +++ b/test/distributed/test_dist_link_neighbor_sampler.py @@ -16,7 +16,6 @@ ) from torch_geometric.distributed.dist_context import DistContext from torch_geometric.distributed.event_loop import ConcurrentEventLoop -from torch_geometric.distributed.partition import load_partition_info from torch_geometric.distributed.rpc import init_rpc, shutdown_rpc from torch_geometric.sampler import EdgeSamplerInput, NeighborSampler from torch_geometric.sampler.neighbor_sampler import edge_sample @@ -79,18 +78,6 @@ def create_hetero_data(tmp_path: str, rank: int): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) other_graph_store = LocalGraphStore.from_partition(tmp_path, int(not rank)) feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) - ( - meta, - num_partitions, - partition_idx, - node_pb, - edge_pb, - ) = load_partition_info(tmp_path, rank) - graph_store.partition_idx = feature_store.partition_idx = partition_idx - graph_store.num_partitions = feature_store.num_partitions = num_partitions - graph_store.node_pb = feature_store.node_feat_pb = node_pb - graph_store.edge_pb = feature_store.edge_feat_pb = edge_pb - graph_store.meta = feature_store.meta = meta return (feature_store, graph_store), other_graph_store diff --git a/test/distributed/test_dist_neighbor_loader.py b/test/distributed/test_dist_neighbor_loader.py index 31ef5103d3e8..79dfb9ca53ef 100644 --- a/test/distributed/test_dist_neighbor_loader.py +++ b/test/distributed/test_dist_neighbor_loader.py @@ -14,31 +14,12 @@ LocalGraphStore, Partitioner, ) -from torch_geometric.distributed.partition import load_partition_info from torch_geometric.testing import onlyDistributedTest def create_dist_data(tmp_path: str, rank: int): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) - ( - meta, - num_partitions, - partition_idx, - node_pb, - edge_pb, - ) = load_partition_info(tmp_path, rank) - graph_store.partition_idx = partition_idx - graph_store.num_partitions = num_partitions - graph_store.node_pb = node_pb - graph_store.edge_pb = edge_pb - graph_store.meta = meta - - feat_store.partition_idx = partition_idx - feat_store.num_partitions = num_partitions - feat_store.node_feat_pb = node_pb - feat_store.edge_feat_pb = edge_pb - feat_store.meta = meta return feat_store, graph_store diff --git a/test/distributed/test_dist_neighbor_sampler.py b/test/distributed/test_dist_neighbor_sampler.py index ac3ea9c29f0a..a2e0be7181bd 100644 --- a/test/distributed/test_dist_neighbor_sampler.py +++ b/test/distributed/test_dist_neighbor_sampler.py @@ -15,7 +15,6 @@ ) from torch_geometric.distributed.dist_context import DistContext from torch_geometric.distributed.event_loop import ConcurrentEventLoop -from torch_geometric.distributed.partition import load_partition_info from torch_geometric.distributed.rpc import init_rpc, shutdown_rpc from torch_geometric.sampler import NeighborSampler, NodeSamplerInput from torch_geometric.sampler.neighbor_sampler import node_sample @@ -79,18 +78,6 @@ def create_hetero_data( ): graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank) feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank) - ( - meta, - num_partitions, - partition_idx, - node_pb, - edge_pb, - ) = load_partition_info(tmp_path, rank) - graph_store.partition_idx = feature_store.partition_idx = partition_idx - graph_store.num_partitions = feature_store.num_partitions = num_partitions - graph_store.node_pb = feature_store.node_feat_pb = node_pb - graph_store.edge_pb = feature_store.edge_feat_pb = edge_pb - graph_store.meta = feature_store.meta = meta return feature_store, graph_store diff --git a/torch_geometric/distributed/dist_loader.py b/torch_geometric/distributed/dist_loader.py index c3a82f58caa1..b8ecdf2da291 100644 --- a/torch_geometric/distributed/dist_loader.py +++ b/torch_geometric/distributed/dist_loader.py @@ -3,6 +3,7 @@ import os from typing import Any, Optional, Union +import torch.distributed import torch.multiprocessing as mp from torch_geometric.distributed import DistNeighborSampler @@ -101,6 +102,9 @@ def reset_channel(self, channel=None): logging.debug(f'{self} Resetting msg channel') while not self.channel.empty(): self.channel.get_nowait() + + torch.distributed.barrier() + self.channel = channel or mp.Queue() self.dist_sampler.channel = self.channel @@ -142,10 +146,17 @@ def __repr__(self) -> str: return f'{self.__class__.__name__}(pid={self.pid})' def __enter__(self) -> DataLoaderIterator: + # fetch a single batch for init + self._prefetch_old = self.prefetch_factor + self.prefetch_factor = 1 self._iterator = self._get_iterator() return self._iterator def __exit__(self, *args) -> None: - del self._iterator if self.channel: self.reset_channel() + if self._iterator: + del self._iterator + torch.distributed.barrier() + self._iterator = None + self.prefetch_factor = self._prefetch_old diff --git a/torch_geometric/distributed/dist_neighbor_loader.py b/torch_geometric/distributed/dist_neighbor_loader.py index 9e6021f4c631..56557f8be44c 100644 --- a/torch_geometric/distributed/dist_neighbor_loader.py +++ b/torch_geometric/distributed/dist_neighbor_loader.py @@ -55,6 +55,7 @@ def __init__( time_attr: Optional[str] = None, transform: Optional[Callable] = None, concurrency: int = 1, + num_rpc_threads: int = 16, filter_per_worker: Optional[bool] = False, async_sampling: bool = True, device: Optional[torch.device] = None, @@ -93,6 +94,7 @@ def __init__( master_port=master_port, current_ctx=current_ctx, dist_sampler=dist_sampler, + num_rpc_threads=num_rpc_threads, **kwargs, ) NodeLoader.__init__( diff --git a/torch_geometric/distributed/dist_neighbor_sampler.py b/torch_geometric/distributed/dist_neighbor_sampler.py index 9b6e94e896ac..ae1f69505a47 100644 --- a/torch_geometric/distributed/dist_neighbor_sampler.py +++ b/torch_geometric/distributed/dist_neighbor_sampler.py @@ -999,9 +999,9 @@ async def _collate_fn( # Collect node labels of input node type. node_labels = self.feature_store.labels if node_labels is not None: - nlabels = node_labels[output.node[self.input_type]] - else: - nlabels = None + nlabels = { + self.input_type: node_labels[output.node[self.input_type]] + } # Collect node features. if output.node is not None: for ntype in output.node.keys(): @@ -1041,8 +1041,8 @@ async def _collate_fn( is_node_feat=True, index=output.node) nfeats = await to_asyncio_future(fut) nfeats = nfeats.to(torch.device("cpu")) - # else: - efeats = None + else: + nfeats = None # Collect edge features. if output.edge is not None and self.with_edge_attr: fut = self.feature_store.lookup_features( diff --git a/torch_geometric/distributed/local_feature_store.py b/torch_geometric/distributed/local_feature_store.py index f59442084b17..c7a0e4210e16 100644 --- a/torch_geometric/distributed/local_feature_store.py +++ b/torch_geometric/distributed/local_feature_store.py @@ -1,5 +1,4 @@ import copy -import json import os.path as osp from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -9,6 +8,7 @@ from torch_geometric.data import FeatureStore, TensorAttr from torch_geometric.data.feature_store import _FieldStatus +from torch_geometric.distributed.partition import load_partition_info from torch_geometric.distributed.rpc import ( RPCCallBase, RPCRouter, @@ -397,11 +397,21 @@ def from_hetero_data( @classmethod def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore': - with open(osp.join(root, 'META.json'), 'r') as f: - meta = json.load(f) - part_dir = osp.join(root, f'part_{pid}') assert osp.exists(part_dir) + feat_store = cls() + ( + meta, + num_partitions, + partition_idx, + node_pb, + edge_pb, + ) = load_partition_info(root, pid) + feat_store.num_partitions = num_partitions + feat_store.partition_idx = partition_idx + feat_store.node_feat_pb = node_pb + feat_store.edge_feat_pb = edge_pb + feat_store.meta = meta node_feats: Optional[Dict[str, Any]] = None if osp.exists(osp.join(part_dir, 'node_feats.pt')): @@ -411,8 +421,6 @@ def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore': if osp.exists(osp.join(part_dir, 'edge_feats.pt')): edge_feats = torch.load(osp.join(part_dir, 'edge_feats.pt')) - feat_store = cls() - if not meta['is_hetero'] and node_feats is not None: feat_store.put_global_id(node_feats['global_id'], group_name=None) for key, value in node_feats['feats'].items(): diff --git a/torch_geometric/distributed/local_graph_store.py b/torch_geometric/distributed/local_graph_store.py index 1de90c30039d..5b7366591064 100644 --- a/torch_geometric/distributed/local_graph_store.py +++ b/torch_geometric/distributed/local_graph_store.py @@ -1,4 +1,3 @@ -import json import os.path as osp from typing import Any, Dict, List, Optional, Tuple, Union @@ -6,6 +5,7 @@ from torch import Tensor from torch_geometric.data import EdgeAttr, GraphStore +from torch_geometric.distributed.partition import load_partition_info from torch_geometric.typing import EdgeTensorType, EdgeType, NodeType from torch_geometric.utils import sort_edge_index @@ -169,14 +169,23 @@ def from_hetero_data( @classmethod def from_partition(cls, root: str, pid: int) -> 'LocalGraphStore': - with open(osp.join(root, 'META.json'), 'r') as f: - meta = json.load(f) - part_dir = osp.join(root, f'part_{pid}') assert osp.exists(part_dir) + graph_store = cls() + ( + meta, + num_partitions, + partition_idx, + node_pb, + edge_pb, + ) = load_partition_info(root, pid) + graph_store.num_partitions = num_partitions + graph_store.partition_idx = partition_idx + graph_store.node_pb = node_pb + graph_store.edge_pb = edge_pb + graph_store.meta = meta graph_data = torch.load(osp.join(part_dir, 'graph.pt')) - graph_store = cls() graph_store.is_sorted = meta['is_sorted'] if not meta['is_hetero']: diff --git a/torch_geometric/distributed/partition.py b/torch_geometric/distributed/partition.py index 6ed3350e99df..806b243cec7c 100644 --- a/torch_geometric/distributed/partition.py +++ b/torch_geometric/distributed/partition.py @@ -7,8 +7,8 @@ import torch +import torch_geometric.distributed as pyg_dist from torch_geometric.data import Data, HeteroData -from torch_geometric.distributed.utils import as_str from torch_geometric.loader.cluster import ClusterData from torch_geometric.typing import Dict, EdgeType, EdgeTypeStr, NodeType, Tuple from torch_geometric.utils import index_sort @@ -375,13 +375,13 @@ def load_partition_info( node_pb_dir = osp.join(root_dir, 'node_map') for ntype in meta['node_types']: node_pb_dict[ntype] = torch.load( - osp.join(node_pb_dir, f'{as_str(ntype)}.pt')) + osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt')) edge_pb_dict = {} edge_pb_dir = osp.join(root_dir, 'edge_map') for etype in meta['edge_types']: edge_pb_dict[tuple(etype)] = torch.load( - osp.join(edge_pb_dir, f'{as_str(etype)}.pt')) + osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt')) return (meta, num_partitions, partition_idx, node_pb_dict, edge_pb_dict) diff --git a/torch_geometric/distributed/rpc.py b/torch_geometric/distributed/rpc.py index dede952239b3..5182854456e0 100644 --- a/torch_geometric/distributed/rpc.py +++ b/torch_geometric/distributed/rpc.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional +from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch.distributed import rpc from torch_geometric.distributed.dist_context import DistContext, DistRole @@ -11,7 +12,6 @@ def rpc_is_initialized() -> bool: - from torch._C._distributed_rpc import _is_current_rpc_agent_set return _is_current_rpc_agent_set() @@ -75,11 +75,11 @@ def shutdown_rpc(id: str = None, graceful: bool = True, timeout: float = 240.0): with _rpc_init_lock: if rpc_is_initialized(): - global_barrier(timeout=timeout) + logging.info(f"Shutdown RPC in {id}" + f"{' gracefully' if graceful else ''}") rpc.shutdown(graceful, timeout) - logging.debug(f'Closed RPC in {id} (graceful={graceful})') else: - logging.error(f'RPC in {id} not initialized.') + logging.info(f'RPC in {id} not initialized.') class RPCRouter: diff --git a/torch_geometric/distributed/utils.py b/torch_geometric/distributed/utils.py index 1bc1b8b48fc9..a20981f61ca7 100644 --- a/torch_geometric/distributed/utils.py +++ b/torch_geometric/distributed/utils.py @@ -6,7 +6,8 @@ from torch import Tensor from torch_geometric.data import HeteroData -from torch_geometric.distributed import LocalFeatureStore, LocalGraphStore +from torch_geometric.distributed.local_feature_store import LocalFeatureStore +from torch_geometric.distributed.local_graph_store import LocalGraphStore from torch_geometric.sampler import SamplerOutput from torch_geometric.typing import EdgeType, NodeType @@ -87,7 +88,6 @@ def remove_duplicates( batch: Optional[Tensor] = None, disjoint: bool = False, ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: - num_nodes = node.numel() node_combined = torch.cat([node, out.node]) @@ -152,18 +152,18 @@ def filter_dist_store( required_node_attrs.append(attr) data[attr.group_name].num_nodes = attr.index.size(0) - if nfeats is not None: + if nfeats: for attr in required_node_attrs: if nfeats[attr.group_name] is not None: data[attr.group_name][attr.attr_name] = nfeats[attr.group_name] - if efeats is not None: + if efeats: for attr in required_edge_attrs: if efeats[attr.edge_type] is not None: data[attr.edge_type].edge_attr = efeats[attr.edge_type] - if nlabels is not None: - data[input_type].y = nlabels + if nlabels: + data[input_type].y = nlabels[input_type] return data diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py index 3a16bf618ab2..2c040ae1f9c2 100644 --- a/torch_geometric/loader/node_loader.py +++ b/torch_geometric/loader/node_loader.py @@ -212,9 +212,11 @@ def filter_fn( if (self.node_sampler.__class__.__name__ == 'DistNeighborSampler'): import torch_geometric.distributed as dist + data = dist.utils.filter_dist_store( *self.data, out.node, out.row, out.col, out.edge, - self.custom_cls, out.metadata) + self.custom_cls, out.metadata, + self.input_data.input_type) else: data = filter_custom_hetero_store( # *self.data, out.node, out.row, out.col, out.edge, @@ -261,8 +263,5 @@ def _get_iterator(self) -> Iterator: # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) - def __enter__(self): - return self - def __repr__(self) -> str: return f'{self.__class__.__name__}()'