Skip to content

Commit e576351

Browse files
authored
Merge pull request #2 from thunderock/add_gpu_tests
Add gpu tests
2 parents d9017b3 + e6cce98 commit e576351

File tree

6 files changed

+54
-1
lines changed

6 files changed

+54
-1
lines changed

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ repos:
1818
rev: v0.0.291
1919
hooks:
2020
- id: ruff
21+
args: [ --fix ]

Makefile

+8-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,11 @@ format:
3030
@echo "formatting..."
3131
@poetry install --only lint
3232
@poetry run black .
33-
@poetry run pre-commit run --all-files
33+
# ruff fix
34+
@poetry run pre-commit run --all-files --config .pre-commit-config.yaml
35+
36+
.PHONY: run_tests
37+
run_tests:
38+
@echo "running tests..."
39+
@poetry install --only main --only test -vvv
40+
@poetry run pytest -q tests

graph_ml/utils/config.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
import torch
3+
4+
5+
OS = os.name
6+
DEVICE_TYPE = "cpu"
7+
if OS == "posix":
8+
if torch.cuda.is_available():
9+
DEVICE_TYPE = "cuda"
10+
elif torch.backends.mps.is_available():
11+
DEVICE_TYPE = "mps"
12+
13+
DEVICE = torch.device(DEVICE_TYPE)
14+
15+
GPU_AVAILABLE = DEVICE_TYPE in ["cuda", "mps"]

graph_ml/utils/gpu_utils.py

Whitespace-only changes.

pyproject.toml

+5
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,8 @@ pytest = "^8.0.2"
6464
[tool.poetry.group.lint.dependencies]
6565
pre-commit = "^3.6.2"
6666
black = "^24.2.0"
67+
68+
69+
70+
# poetry add git+ssh://[email protected]/thunderock/graph_ml.git
71+
# poetry add whl_url

tests/test_gpu_available.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
import os
3+
4+
from graph_ml.utils import config
5+
6+
OS = os.name
7+
8+
9+
def test_target_os():
10+
assert OS == "posix"
11+
12+
13+
def test_device_type():
14+
assert config.DEVICE_TYPE in ["cpu", "cuda", "mps"]
15+
16+
17+
def test_gpu_available():
18+
if config.DEVICE_TYPE == "cuda":
19+
assert config.GPU_AVAILABLE
20+
elif config.DEVICE_TYPE == "mps":
21+
assert config.GPU_AVAILABLE
22+
elif config.DEVICE_TYPE == "cpu":
23+
assert not config.GPU_AVAILABLE
24+
else:
25+
assert False

0 commit comments

Comments
 (0)