diff --git a/CHANGELOG.md b/CHANGELOG.md
index 552c404b41cb..61fb23157166 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added support for XPU device in `profileit` decorator ([#8532](https://github.com/pyg-team/pytorch_geometric/pull/8532))
 - Added `KNNIndex` exclusion logic ([#8573](https://github.com/pyg-team/pytorch_geometric/pull/8573))
 - Added warning when calling `dataset.num_classes` on regression problems ([#8550](https://github.com/pyg-team/pytorch_geometric/pull/8550))
 - Added relabel node functionality to `dropout_node` ([#8524](https://github.com/pyg-team/pytorch_geometric/pull/8524))
diff --git a/test/profile/test_profile.py b/test/profile/test_profile.py
index 24142d1073b9..8b8e4d5ad199 100644
--- a/test/profile/test_profile.py
+++ b/test/profile/test_profile.py
@@ -46,7 +46,7 @@ def test_timeit(device):
 @onlyCUDA
 @onlyOnline
 @withPackage('pytorch_memlab')
-def test_profileit(get_dataset):
+def test_profileit_cuda(get_dataset):
     warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*')
 
     dataset = get_dataset(name='Cora')
@@ -55,7 +55,7 @@ def test_profileit(get_dataset):
                       out_channels=dataset.num_classes).cuda()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
-    @profileit()
+    @profileit('cuda')
     def train(model, x, edge_index, y):
         model.train()
         optimizer.zero_grad()
@@ -67,11 +67,10 @@ def train(model, x, edge_index, y):
     stats_list = []
     for epoch in range(5):
         _, stats = train(model, data.x, data.edge_index, data.y)
-        assert len(stats) == 6
         assert stats.time > 0
-        assert stats.max_allocated_cuda > 0
-        assert stats.max_reserved_cuda > 0
-        assert stats.max_active_cuda > 0
+        assert stats.max_allocated_gpu > 0
+        assert stats.max_reserved_gpu > 0
+        assert stats.max_active_gpu > 0
         assert stats.nvidia_smi_free_cuda > 0
         assert stats.nvidia_smi_used_cuda > 0
 
@@ -79,16 +78,57 @@ def train(model, x, edge_index, y):
             stats_list.append(stats)
 
     stats_summary = get_stats_summary(stats_list)
-    assert len(stats_summary) == 7
     assert stats_summary.time_mean > 0
     assert stats_summary.time_std > 0
-    assert stats_summary.max_allocated_cuda > 0
-    assert stats_summary.max_reserved_cuda > 0
-    assert stats_summary.max_active_cuda > 0
+    assert stats_summary.max_allocated_gpu > 0
+    assert stats_summary.max_reserved_gpu > 0
+    assert stats_summary.max_active_gpu > 0
     assert stats_summary.min_nvidia_smi_free_cuda > 0
     assert stats_summary.max_nvidia_smi_used_cuda > 0
 
 
+@onlyXPU
+def test_profileit_xpu(get_dataset):
+    warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*')
+
+    dataset = get_dataset(name='Cora')
+    data = dataset[0].to('xpu')
+    model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,
+                      out_channels=dataset.num_classes).to('xpu')
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+    @profileit('xpu')
+    def train(model, x, edge_index, y):
+        model.train()
+        optimizer.zero_grad()
+        out = model(x, edge_index)
+        loss = F.cross_entropy(out, y)
+        loss.backward()
+        return float(loss)
+
+    stats_list = []
+    for epoch in range(5):
+        _, stats = train(model, data.x, data.edge_index, data.y)
+        assert stats.time > 0
+        assert stats.max_allocated_gpu > 0
+        assert stats.max_reserved_gpu > 0
+        assert stats.max_active_gpu > 0
+        assert not hasattr(stats, 'nvidia_smi_free_cuda')
+        assert not hasattr(stats, 'nvidia_smi_used_cuda')
+
+        if epoch >= 2:  # Warm-up
+            stats_list.append(stats)
+
+    stats_summary = get_stats_summary(stats_list)
+    assert stats_summary.time_mean > 0
+    assert stats_summary.time_std > 0
+    assert stats_summary.max_allocated_gpu > 0
+    assert stats_summary.max_reserved_gpu > 0
+    assert stats_summary.max_active_gpu > 0
+    assert not hasattr(stats_summary, 'min_nvidia_smi_free_cuda')
+    assert not hasattr(stats_summary, 'max_nvidia_smi_used_cuda')
+
+
 @withCUDA
 @onlyOnline
 def test_torch_profile(capfd, get_dataset, device):
diff --git a/test/profile/test_profile_utils.py b/test/profile/test_profile_utils.py
index 4f38aa21a24f..48901121e665 100644
--- a/test/profile/test_profile_utils.py
+++ b/test/profile/test_profile_utils.py
@@ -7,6 +7,7 @@
     get_cpu_memory_from_gc,
     get_data_size,
     get_gpu_memory_from_gc,
+    get_gpu_memory_from_ipex,
     get_gpu_memory_from_nvidia_smi,
     get_model_size,
 )
@@ -14,7 +15,7 @@
     byte_to_megabyte,
     medibyte_to_megabyte,
 )
-from torch_geometric.testing import onlyCUDA, withPackage
+from torch_geometric.testing import onlyCUDA, onlyXPU, withPackage
 from torch_geometric.typing import SparseTensor
 
 
@@ -68,6 +69,14 @@ def test_get_gpu_memory_from_nvidia_smi():
     assert used_mem >= 0
 
 
+@onlyXPU
+def test_get_gpu_memory_from_ipex():
+    max_allocated, max_reserved, max_active = get_gpu_memory_from_ipex()
+    assert max_allocated >= 0
+    assert max_reserved >= 0
+    assert max_active >= 0
+
+
 def test_bytes_function():
     assert byte_to_megabyte((1024 * 1024)) == 1.00
     assert medibyte_to_megabyte(1 / 1.0485) == 1.00
diff --git a/torch_geometric/profile/__init__.py b/torch_geometric/profile/__init__.py
index 0a7bb1663ac7..833ee657d0e7 100644
--- a/torch_geometric/profile/__init__.py
+++ b/torch_geometric/profile/__init__.py
@@ -1,20 +1,25 @@
 r"""GNN profiling package."""
 
-from .profile import profileit, timeit, get_stats_summary
+from .benchmark import benchmark
 from .profile import (
-    trace_handler,
+    get_stats_summary,
     print_time_total,
+    profileit,
     rename_profile_file,
+    timeit,
     torch_profile,
+    trace_handler,
     xpu_profile,
 )
-from .utils import count_parameters
-from .utils import get_model_size
-from .utils import get_data_size
-from .utils import get_cpu_memory_from_gc
-from .utils import get_gpu_memory_from_gc
-from .utils import get_gpu_memory_from_nvidia_smi
-from .benchmark import benchmark
+from .utils import (
+    count_parameters,
+    get_cpu_memory_from_gc,
+    get_data_size,
+    get_gpu_memory_from_gc,
+    get_gpu_memory_from_ipex,
+    get_gpu_memory_from_nvidia_smi,
+    get_model_size,
+)
 
 __all__ = [
     'profileit',
@@ -31,6 +36,7 @@
     'get_cpu_memory_from_gc',
     'get_gpu_memory_from_gc',
     'get_gpu_memory_from_nvidia_smi',
+    'get_gpu_memory_from_ipex',
     'benchmark',
 ]
diff --git a/torch_geometric/profile/profile.py b/torch_geometric/profile/profile.py
index 748d7f57b7ad..54767d5dd359 100644
--- a/torch_geometric/profile/profile.py
+++ b/torch_geometric/profile/profile.py
@@ -2,7 +2,8 @@
 import pathlib
 import time
 from contextlib import ContextDecorator, contextmanager
-from typing import Any, List, NamedTuple, Tuple
+from dataclasses import dataclass
+from typing import Any, List, Tuple, Union
 
 import torch
 from torch.autograd.profiler import EventList
@@ -10,40 +11,54 @@
 from torch_geometric.profile.utils import (
     byte_to_megabyte,
+    get_gpu_memory_from_ipex,
     get_gpu_memory_from_nvidia_smi,
 )
 
 
-class Stats(NamedTuple):
+@dataclass
+class GPUStats:
     time: float
-    max_allocated_cuda: float
-    max_reserved_cuda: float
-    max_active_cuda: float
+    max_allocated_gpu: float
+    max_reserved_gpu: float
+    max_active_gpu: float
+
+
+@dataclass
+class CUDAStats(GPUStats):
     nvidia_smi_free_cuda: float
     nvidia_smi_used_cuda: float
 
 
-class StatsSummary(NamedTuple):
+@dataclass
+class GPUStatsSummary:
     time_mean: float
     time_std: float
-    max_allocated_cuda: float
-    max_reserved_cuda: float
-    max_active_cuda: float
+    max_allocated_gpu: float
+    max_reserved_gpu: float
+    max_active_gpu: float
+
+
+@dataclass
+class CUDAStatsSummary(GPUStatsSummary):
     min_nvidia_smi_free_cuda: float
     max_nvidia_smi_used_cuda: float
 
 
-def profileit():  # pragma: no cover
+def profileit(device: str):  # pragma: no cover
     r"""A decorator to facilitate profiling a function, *e.g.*, obtaining
     training runtime and memory statistics of a specific model on a specific
     dataset.
-    Returns a :obj:`Stats` object with the attributes :obj:`time`,
-    :obj:`max_active_cuda`, :obj:`max_reserved_cuda`, :obj:`max_active_cuda`,
-    :obj:`nvidia_smi_free_cuda`, :obj:`nvidia_smi_used_cuda`.
+    Returns a :obj:`GPUStats` object if :obj:`device` is :obj:`xpu`, or an
+    extended :obj:`CUDAStats` object if :obj:`device` is :obj:`cuda`.
+
+    Args:
+        device (str): Target device for profiling. Options are:
+            :obj:`cuda` and :obj:`xpu`.
 
     .. code-block:: python
 
-        @profileit()
+        @profileit("cuda")
         def train(model, optimizer, x, edge_index, y):
             optimizer.zero_grad()
             out = model(x, edge_index)
@@ -55,56 +70,71 @@ def train(model, optimizer, x, edge_index, y):
             loss, stats = train(model, x, edge_index, y)
     """
     def decorator(func):
-        def wrapper(*args, **kwargs) -> Tuple[Any, Stats]:
-            from pytorch_memlab import LineProfiler
-
+        def wrapper(
+                *args, **kwargs
+        ) -> Union[Tuple[Any, GPUStats], Tuple[Any, CUDAStats]]:
             model = args[0]
             if not isinstance(model, torch.nn.Module):
                 raise AttributeError(
                     'First argument for profiling needs to be torch.nn.Module')
+            if device not in ['cuda', 'xpu']:
+                raise AttributeError(
+                    "The profiling decorator supports only CUDA and "
+                    "XPU devices")
 
-            device = None
+            device_id = None
             for arg in list(args) + list(kwargs.values()):
                 if isinstance(arg, torch.Tensor):
-                    device = arg.get_device()
+                    device_id = arg.get_device()
                     break
-            if device is None:
+            if device_id is None:
                 raise AttributeError(
-                    "Could not infer CUDA device from the args in the "
+                    "Could not infer GPU device from the args in the "
                     "function being profiled")
-            if device == -1:
+            if device_id == -1:
                 raise RuntimeError(
                     "The profiling decorator does not support profiling "
-                    "on non CUDA devices")
+                    "on non GPU devices")
+
+            is_cuda = device == 'cuda'
+            torch_gpu = torch.cuda if is_cuda else torch.xpu
 
-            # Init `pytorch_memlab` for analyzing the model forward pass:
-            line_profiler = LineProfiler(target_gpu=device)
-            line_profiler.enable()
-            line_profiler.add_function(args[0].forward)
+            # `pytorch_memlab` supports only CUDA devices:
+            if is_cuda:
+                from pytorch_memlab import LineProfiler
 
-            start = torch.cuda.Event(enable_timing=True)
-            end = torch.cuda.Event(enable_timing=True)
+                # Init `pytorch_memlab` for analyzing the model forward pass:
+                line_profiler = LineProfiler(target_gpu=device_id)
+                line_profiler.enable()
+                line_profiler.add_function(args[0].forward)
+
+            start = torch_gpu.Event(enable_timing=True)
+            end = torch_gpu.Event(enable_timing=True)
             start.record()
 
             out = func(*args, **kwargs)
 
             end.record()
-            torch.cuda.synchronize()
+            torch_gpu.synchronize()
             time = start.elapsed_time(end) / 1000
 
-            # Get the global memory statistics collected by `pytorch_memlab`:
-            memlab = read_from_memlab(line_profiler)
-            max_allocated_cuda, max_reserved_cuda, max_active_cuda = memlab
-            line_profiler.disable()
+            if is_cuda:
+                # Get the global memory statistics collected
+                # by `pytorch_memlab`:
+                memlab = read_from_memlab(line_profiler)
+                max_allocated, max_reserved, max_active = memlab
+                line_profiler.disable()
 
-            # Get additional information from `nvidia-smi`:
-            free_cuda, used_cuda = get_gpu_memory_from_nvidia_smi(
-                device=device)
+                # Get additional information from `nvidia-smi`:
+                free_cuda, used_cuda = get_gpu_memory_from_nvidia_smi(
+                    device=device_id)
 
-            stats = Stats(time, max_allocated_cuda, max_reserved_cuda,
-                          max_active_cuda, free_cuda, used_cuda)
-
-            return out, stats
+                stats = CUDAStats(time, max_allocated, max_reserved,
+                                  max_active, free_cuda, used_cuda)
+                return out, stats
+            else:
+                stats = GPUStats(time, *get_gpu_memory_from_ipex())
+                return out, stats
 
         return wrapper
 
@@ -162,28 +192,38 @@ def reset(self):
         self.__enter__()
 
 
-def get_stats_summary(stats_list: List[Stats]):  # pragma: no cover
+def get_stats_summary(
+    stats_list: Union[List[GPUStats], List[CUDAStats]]
+) -> Union[GPUStatsSummary, CUDAStatsSummary]:  # pragma: no cover
     r"""Creates a summary of collected runtime and memory statistics.
 
-    Returns a :obj:`StatsSummary` object with the attributes :obj:`time_mean`,
-    :obj:`time_std`,
-    :obj:`max_active_cuda`, :obj:`max_reserved_cuda`, :obj:`max_active_cuda`,
-    :obj:`min_nvidia_smi_free_cuda`, :obj:`max_nvidia_smi_used_cuda`.
+    Returns a :obj:`CUDAStatsSummary` object if a list of :obj:`CUDAStats`
+    objects was passed, otherwise (for a list of :obj:`GPUStats` objects)
+    returns a :obj:`GPUStatsSummary` object.
 
     Args:
-        stats_list (List[Stats]): A list of :obj:`Stats` objects, as returned
-            by :meth:`~torch_geometric.profile.profileit`.
+        stats_list (Union[List[GPUStats], List[CUDAStats]]): A list of
+            :obj:`GPUStats` or :obj:`CUDAStats` objects, as returned by
+            :meth:`~torch_geometric.profile.profileit`.
     """
-    return StatsSummary(
+    # Calculate statistics shared by both device types:
+    kwargs = dict(
         time_mean=float(torch.tensor([s.time for s in stats_list]).mean()),
         time_std=float(torch.tensor([s.time for s in stats_list]).std()),
-        max_allocated_cuda=max([s.max_allocated_cuda for s in stats_list]),
-        max_reserved_cuda=max([s.max_reserved_cuda for s in stats_list]),
-        max_active_cuda=max([s.max_active_cuda for s in stats_list]),
-        min_nvidia_smi_free_cuda=min(
-            [s.nvidia_smi_free_cuda for s in stats_list]),
-        max_nvidia_smi_used_cuda=max(
-            [s.nvidia_smi_used_cuda for s in stats_list]),
-    )
+        max_allocated_gpu=max([s.max_allocated_gpu for s in stats_list]),
+        max_reserved_gpu=max([s.max_reserved_gpu for s in stats_list]),
+        max_active_gpu=max([s.max_active_gpu for s in stats_list]))
+
+    # `CUDAStats` extends `GPUStats`, so test for the subclass explicitly:
+    if all(isinstance(s, CUDAStats) for s in stats_list):
+        return CUDAStatsSummary(
+            **kwargs,
+            min_nvidia_smi_free_cuda=min(
+                [s.nvidia_smi_free_cuda for s in stats_list]),
+            max_nvidia_smi_used_cuda=max(
+                [s.nvidia_smi_used_cuda for s in stats_list]),
+        )
+    else:
+        return GPUStatsSummary(**kwargs)
 
 
 ###############################################################################
diff --git a/torch_geometric/profile/utils.py b/torch_geometric/profile/utils.py
index 71035a6326c0..83c57a4ff7d5 100644
--- a/torch_geometric/profile/utils.py
+++ b/torch_geometric/profile/utils.py
@@ -135,6 +135,28 @@ def get_gpu_memory_from_nvidia_smi(  # pragma: no cover
     return free_mem, used_mem
 
 
+def get_gpu_memory_from_ipex(
+        device: int = 0,
+        digits: int = 2) -> Tuple[float, float, float]:  # pragma: no cover
+    r"""Returns the XPU memory statistics.
+
+    Args:
+        device (int, optional): The GPU device identifier. (default: :obj:`0`)
+        digits (int, optional): The number of decimals to use for megabytes.
+            (default: :obj:`2`)
+    """
+    import intel_extension_for_pytorch as ipex
+    stats = ipex.xpu.memory_stats_as_nested_dict(device)
+    max_allocated = stats['allocated_bytes']['all']['peak']
+    max_reserved = stats['reserved_bytes']['all']['peak']
+    max_active = stats['active_bytes']['all']['peak']
+    max_allocated = byte_to_megabyte(max_allocated, digits)
+    max_reserved = byte_to_megabyte(max_reserved, digits)
+    max_active = byte_to_megabyte(max_active, digits)
+    ipex.xpu.reset_peak_memory_stats()
+    return max_allocated, max_reserved, max_active
+
+
 ###############################################################################
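Usage sketch (not part of the diff): a minimal end-to-end example of the updated API, mirroring the tests above. It assumes an XPU-enabled PyTorch build via `intel_extension_for_pytorch` and an already-loaded, Cora-like `dataset` object; the dataset and hyperparameters are illustrative.

```python
import torch
import torch.nn.functional as F

from torch_geometric.nn import GraphSAGE
from torch_geometric.profile import get_stats_summary, profileit

# `dataset` is assumed to be a pre-loaded, Cora-like dataset (illustrative);
# `device` may be 'cuda' or 'xpu', the only two options `profileit` accepts:
device = 'xpu'
data = dataset[0].to(device)
model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,
                  out_channels=dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


@profileit(device)  # Yields `CUDAStats` on CUDA, plain `GPUStats` on XPU.
def train(model, x, edge_index, y):
    model.train()
    optimizer.zero_grad()
    out = model(x, edge_index)
    loss = F.cross_entropy(out, y)
    loss.backward()
    return float(loss)


stats_list = []
for epoch in range(5):
    loss, stats = train(model, data.x, data.edge_index, data.y)
    if epoch >= 2:  # Skip warm-up epochs before aggregating.
        stats_list.append(stats)

# `get_stats_summary` dispatches on the element type: a list of `CUDAStats`
# yields a `CUDAStatsSummary` (with `nvidia-smi` extrema), while any other
# list of `GPUStats` yields a `GPUStatsSummary`:
stats_summary = get_stats_summary(stats_list)
print(stats_summary.time_mean, stats_summary.max_allocated_gpu)
```

The decorated function must take the model as its first argument and at least one GPU tensor among its arguments, since the wrapper infers the device identifier from the first tensor it finds.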
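Similarly, a short sketch of calling the new helper directly on an XPU machine (again assuming `intel_extension_for_pytorch` is installed). Note that the call also resets the peak statistics, so each invocation reports peaks accumulated since the previous one:

```python
from torch_geometric.profile import get_gpu_memory_from_ipex

# Peak allocated/reserved/active XPU memory of device 0 in megabytes,
# rounded to two decimal places (the `digits` default):
max_allocated, max_reserved, max_active = get_gpu_memory_from_ipex(device=0)
print(f'allocated={max_allocated}MB, reserved={max_reserved}MB, '
      f'active={max_active}MB')
```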