Skip to content

Commit c04ef35

Browse files
author
Ashutosh Tiwari
committed
added gpu tests, need to work on expanding python version range
1 parent 45c43d8 commit c04ef35

File tree

5 files changed

+52
-1
lines changed

5 files changed

+52
-1
lines changed

Makefile

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,9 @@ format:
3131
@poetry install --only lint
3232
@poetry run black .
3333
@poetry run pre-commit run --all-files
34+
35+
.PHONY: run_tests
36+
run_tests:
37+
@echo "running tests..."
38+
@poetry install --only main --only test -vvv
39+
@poetry run pytest -q tests

graph_ml/utils/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import os, sys
2+
import torch
3+
4+
5+
OS = os.name
6+
DEVICE_TYPE = 'cpu'
7+
if OS == 'posix':
8+
if torch.cuda.is_available():
9+
DEVICE_TYPE = 'cuda'
10+
elif torch.backends.mps.is_available():
11+
DEVICE_TYPE = 'mps'
12+
13+
DEVICE = torch.device(DEVICE_TYPE)
14+
15+
GPU_AVAILABLE = DEVICE_TYPE in ['cuda', 'mps']
16+

graph_ml/utils/gpu_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import torch
2+
3+
OS =

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,8 @@ pytest = "^8.0.2"
6464
[tool.poetry.group.lint.dependencies]
6565
pre-commit = "^3.6.2"
6666
black = "^24.2.0"
67+
68+
69+
70+
# poetry add git+ssh://[email protected]/thunderock/graph_ml.git
71+
# poetry add whl_url

tests/test_gpu_available.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,25 @@
11
from __future__ import annotations
2+
import os
23

34
import pytest
4-
from graph_ml import add
5+
from graph_ml.utils import config
6+
7+
OS = os.name
8+
9+
def test_target_os():
10+
assert OS == "posix"
11+
12+
13+
def test_device_type():
14+
assert config.DEVICE_TYPE in ["cpu", "cuda", "mps"]
15+
16+
def test_gpu_available():
17+
if config.DEVICE_TYPE == "cuda":
18+
assert config.GPU_AVAILABLE
19+
elif config.DEVICE_TYPE == "mps":
20+
assert config.GPU_AVAILABLE
21+
elif config.DEVICE_TYPE == "cpu":
22+
assert not config.GPU_AVAILABLE
23+
else:
24+
assert False
25+

0 commit comments

Comments
 (0)