Skip to content

Commit

Permalink
Add support for XPU device in profileit decorator (pyg-team#8532)
Browse files Browse the repository at this point in the history
This PR adds support for XPU device in `profileit` decorator.
Additionally, `get_gpu_memory_from_ipex` function was provided, which
collects statistics such as: `max_allocated`, `max_reserved` and
`max_active`.
  • Loading branch information
DamianSzwichtenberg authored Dec 8, 2023
1 parent 64fc4c1 commit 0e7458d
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 77 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for XPU device in `profileit` decorator ([#8532](https://github.com/pyg-team/pytorch_geometric/pull/8532))
- Added `KNNIndex` exclusion logic ([#8573](https://github.com/pyg-team/pytorch_geometric/pull/8573))
- Added warning when calling `dataset.num_classes` on regression problems ([#8550](https://github.com/pyg-team/pytorch_geometric/pull/8550))
- Added relabel node functionality to `dropout_node` ([#8524](https://github.com/pyg-team/pytorch_geometric/pull/8524))
Expand Down
60 changes: 50 additions & 10 deletions test/profile/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_timeit(device):
@onlyCUDA
@onlyOnline
@withPackage('pytorch_memlab')
def test_profileit(get_dataset):
def test_profileit_cuda(get_dataset):
warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*')

dataset = get_dataset(name='Cora')
Expand All @@ -55,7 +55,7 @@ def test_profileit(get_dataset):
out_channels=dataset.num_classes).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

@profileit()
@profileit('cuda')
def train(model, x, edge_index, y):
model.train()
optimizer.zero_grad()
Expand All @@ -67,28 +67,68 @@ def train(model, x, edge_index, y):
stats_list = []
for epoch in range(5):
_, stats = train(model, data.x, data.edge_index, data.y)
assert len(stats) == 6
assert stats.time > 0
assert stats.max_allocated_cuda > 0
assert stats.max_reserved_cuda > 0
assert stats.max_active_cuda > 0
assert stats.max_allocated_gpu > 0
assert stats.max_reserved_gpu > 0
assert stats.max_active_gpu > 0
assert stats.nvidia_smi_free_cuda > 0
assert stats.nvidia_smi_used_cuda > 0

if epoch >= 2: # Warm-up
stats_list.append(stats)

stats_summary = get_stats_summary(stats_list)
assert len(stats_summary) == 7
assert stats_summary.time_mean > 0
assert stats_summary.time_std > 0
assert stats_summary.max_allocated_cuda > 0
assert stats_summary.max_reserved_cuda > 0
assert stats_summary.max_active_cuda > 0
assert stats_summary.max_allocated_gpu > 0
assert stats_summary.max_reserved_gpu > 0
assert stats_summary.max_active_gpu > 0
assert stats_summary.min_nvidia_smi_free_cuda > 0
assert stats_summary.max_nvidia_smi_used_cuda > 0


@onlyXPU
def test_profileit_xpu(get_dataset):
    """Check the statistics reported by ``profileit('xpu')``.

    Trains a small GraphSAGE model on Cora for a few epochs and verifies
    that every per-epoch record, as well as the aggregated summary, carries
    positive timing and GPU memory figures and does *not* carry any of the
    ``nvidia-smi`` based fields.
    """
    warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*')

    dataset = get_dataset(name='Cora')
    # The test is gated by `onlyXPU` and profiles the 'xpu' device, so the
    # data and the model have to be placed on 'xpu' rather than on CUDA.
    data = dataset[0].to('xpu')
    model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,
                      out_channels=dataset.num_classes).to('xpu')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    @profileit('xpu')
    def train(model, x, edge_index, y):
        model.train()
        optimizer.zero_grad()
        out = model(x, edge_index)
        loss = F.cross_entropy(out, y)
        loss.backward()
        return float(loss)

    stats_list = []
    for epoch in range(5):
        _, stats = train(model, data.x, data.edge_index, data.y)
        assert stats.time > 0
        assert stats.max_allocated_gpu > 0
        assert stats.max_reserved_gpu > 0
        assert stats.max_active_gpu > 0
        # The nvidia-smi fields must be absent from XPU statistics.
        assert not hasattr(stats, 'nvidia_smi_free_cuda')
        assert not hasattr(stats, 'nvidia_smi_used_cuda')

        if epoch >= 2:  # Warm-up
            stats_list.append(stats)

    stats_summary = get_stats_summary(stats_list)
    assert stats_summary.time_mean > 0
    assert stats_summary.time_std > 0
    assert stats_summary.max_allocated_gpu > 0
    assert stats_summary.max_reserved_gpu > 0
    assert stats_summary.max_active_gpu > 0
    assert not hasattr(stats_summary, 'min_nvidia_smi_free_cuda')
    assert not hasattr(stats_summary, 'max_nvidia_smi_used_cuda')


@withCUDA
@onlyOnline
def test_torch_profile(capfd, get_dataset, device):
Expand Down
11 changes: 10 additions & 1 deletion test/profile/test_profile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
get_cpu_memory_from_gc,
get_data_size,
get_gpu_memory_from_gc,
get_gpu_memory_from_ipex,
get_gpu_memory_from_nvidia_smi,
get_model_size,
)
from torch_geometric.profile.utils import (
byte_to_megabyte,
medibyte_to_megabyte,
)
from torch_geometric.testing import onlyCUDA, withPackage
from torch_geometric.testing import onlyCUDA, onlyXPU, withPackage
from torch_geometric.typing import SparseTensor


Expand Down Expand Up @@ -68,6 +69,14 @@ def test_get_gpu_memory_from_nvidia_smi():
assert used_mem >= 0


@onlyXPU
def test_get_gpu_memory_from_ipex():
    """Check that `get_gpu_memory_from_ipex` yields three non-negative values."""
    allocated, reserved, active = get_gpu_memory_from_ipex()
    for value in (allocated, reserved, active):
        assert value >= 0


def test_bytes_function():
assert byte_to_megabyte((1024 * 1024)) == 1.00
assert medibyte_to_megabyte(1 / 1.0485) == 1.00
24 changes: 15 additions & 9 deletions torch_geometric/profile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
r"""GNN profiling package."""

from .profile import profileit, timeit, get_stats_summary
from .benchmark import benchmark
from .profile import (
trace_handler,
get_stats_summary,
print_time_total,
profileit,
rename_profile_file,
timeit,
torch_profile,
trace_handler,
xpu_profile,
)
from .utils import count_parameters
from .utils import get_model_size
from .utils import get_data_size
from .utils import get_cpu_memory_from_gc
from .utils import get_gpu_memory_from_gc
from .utils import get_gpu_memory_from_nvidia_smi
from .benchmark import benchmark
from .utils import (
count_parameters,
get_cpu_memory_from_gc,
get_data_size,
get_gpu_memory_from_gc,
get_gpu_memory_from_ipex,
get_gpu_memory_from_nvidia_smi,
get_model_size,
)

__all__ = [
'profileit',
Expand All @@ -31,6 +36,7 @@
'get_cpu_memory_from_gc',
'get_gpu_memory_from_gc',
'get_gpu_memory_from_nvidia_smi',
'get_gpu_memory_from_ipex',
'benchmark',
]

Expand Down
Loading

0 comments on commit 0e7458d

Please sign in to comment.