Skip to content

Commit d491087

Browse files
committed
update skip_if_rocm import
lint
1 parent f52d14a commit d491087

14 files changed

+44
-48
lines changed

test/dtypes/test_affine_quantized.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import unittest
33

44
import torch
5-
from test_utils import skip_if_rocm
65
from torch.testing._internal import common_utils
76
from torch.testing._internal.common_utils import (
87
TestCase,
@@ -22,6 +21,7 @@
2221
TORCH_VERSION_AT_LEAST_2_5,
2322
TORCH_VERSION_AT_LEAST_2_6,
2423
is_sm_at_least_89,
24+
skip_if_rocm,
2525
)
2626

2727

test/dtypes/test_floatx.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import unittest
33

44
import torch
5-
from test_utils import skip_if_rocm
65
from torch.testing._internal.common_utils import (
76
TestCase,
87
instantiate_parametrized_tests,
@@ -28,7 +27,7 @@
2827
fpx_weight_only,
2928
quantize_,
3029
)
31-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
30+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm
3231

3332
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
3433
_Floatx_DTYPES = [(3, 2), (2, 2)]

test/float8/test_base.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
TORCH_VERSION_AT_LEAST_2_5,
1919
is_sm_at_least_89,
2020
is_sm_at_least_90,
21+
skip_if_rocm,
2122
)
2223

2324
if not TORCH_VERSION_AT_LEAST_2_5:
2425
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
2526

2627

27-
from test_utils import skip_if_rocm
28-
2928
from torchao.float8.config import (
3029
CastConfig,
3130
Float8LinearConfig,

test/hqq/test_hqq_affine.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import unittest
22

33
import torch
4-
from test_utils import skip_if_rocm
54

65
from torchao.quantization import (
76
MappingType,
@@ -11,6 +10,7 @@
1110
)
1211
from torchao.utils import (
1312
TORCH_VERSION_AT_LEAST_2_3,
13+
skip_if_rocm,
1414
)
1515

1616
cuda_available = torch.cuda.is_available()

test/integration/test_integration.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
benchmark_model,
8181
is_fbcode,
8282
is_sm_at_least_90,
83+
skip_if_rocm,
8384
unwrap_tensor_subclass,
8485
)
8586

@@ -90,7 +91,6 @@
9091
except ModuleNotFoundError:
9192
has_gemlite = False
9293

93-
from test_utils import skip_if_rocm
9494

9595
logger = logging.getLogger("INFO")
9696

test/kernel/test_galore_downproj.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99
import torch
1010
from galore_test_utils import make_data
11-
from test_utils import skip_if_rocm
1211

1312
from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
1413
from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
14+
from torchao.utils import skip_if_rocm
1515

1616
torch.manual_seed(0)
1717

test/prototype/test_awq.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
import torch
66

77
from torchao.quantization import quantize_
8-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
8+
from torchao.utils import (
9+
TORCH_VERSION_AT_LEAST_2_3,
10+
TORCH_VERSION_AT_LEAST_2_5,
11+
skip_if_rocm,
12+
)
913

1014
if TORCH_VERSION_AT_LEAST_2_3:
1115
from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
1216

13-
from test_utils import skip_if_rocm
14-
1517

1618
class ToyLinearModel(torch.nn.Module):
1719
def __init__(self, m=512, n=256, k=128):

test/prototype/test_low_bit_optim.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
TORCH_VERSION_AT_LEAST_2_4,
3131
TORCH_VERSION_AT_LEAST_2_5,
3232
get_available_devices,
33+
skip_if_rocm,
3334
)
3435

3536
try:
@@ -42,7 +43,6 @@
4243
except ImportError:
4344
lpmm = None
4445

45-
from test_utils import skip_if_rocm
4646

4747
_DEVICES = get_available_devices()
4848

test/prototype/test_splitk.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
except ImportError:
1414
triton_available = False
1515

16-
from test_utils import skip_if_rocm
1716

18-
from torchao.utils import skip_if_compute_capability_less_than
17+
from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm
1918

2019

2120
@unittest.skipIf(not triton_available, "Triton is required but not available")

test/quantization/test_galore_quant.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
dequantize_blockwise,
1414
quantize_blockwise,
1515
)
16-
from test_utils import skip_if_rocm
1716

1817
from torchao.prototype.galore.kernels import (
1918
triton_dequant_blockwise,
2019
triton_quantize_blockwise,
2120
)
21+
from torchao.utils import skip_if_rocm
2222

2323
SEED = 0
2424
torch.manual_seed(SEED)

test/quantization/test_marlin_qqq.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import pytest
55
import torch
6-
from test_utils import skip_if_rocm
76
from torch import nn
87
from torch.testing._internal.common_utils import TestCase, run_tests
98

@@ -20,7 +19,7 @@
2019
MappingType,
2120
choose_qparams_and_quantize_affine_qqq,
2221
)
23-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode
22+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_fbcode, skip_if_rocm
2423

2524

2625
@unittest.skipIf(

test/sparsity/test_marlin.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import pytest
44
import torch
5-
from test_utils import skip_if_rocm
65
from torch import nn
76
from torch.testing._internal.common_utils import TestCase, run_tests
87

@@ -16,7 +15,7 @@
1615
)
1716
from torchao.sparsity.marlin import inject_24, pack_to_marlin_24, unpack_from_marlin_24
1817
from torchao.sparsity.sparse_api import apply_fake_sparsity
19-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
18+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm
2019

2120

2221
class SparseMarlin24(TestCase):

test/test_utils.py

-29
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,11 @@
1-
import functools
21
import unittest
32
from unittest.mock import patch
43

5-
import pytest
64
import torch
75

86
from torchao.utils import TorchAOBaseTensor, torch_version_at_least
97

108

11-
def skip_if_rocm(message=None):
12-
"""Decorator to skip tests on ROCm platform with custom message.
13-
14-
Args:
15-
message (str, optional): Additional information about why the test is skipped.
16-
"""
17-
18-
def decorator(func):
19-
@functools.wraps(func)
20-
def wrapper(*args, **kwargs):
21-
if torch.version.hip is not None:
22-
skip_message = "Skipping the test in ROCm"
23-
if message:
24-
skip_message += f": {message}"
25-
pytest.skip(skip_message)
26-
return func(*args, **kwargs)
27-
28-
return wrapper
29-
30-
# Handle both @skip_if_rocm and @skip_if_rocm() syntax
31-
if callable(message):
32-
func = message
33-
message = None
34-
return decorator(func)
35-
return decorator
36-
37-
389
class TestTorchVersionAtLeast(unittest.TestCase):
3910
def test_torch_version_at_least(self):
4011
test_cases = [

torchao/utils.py

+28
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from math import gcd
88
from typing import Any, Callable, Tuple
99

10+
import pytest
1011
import torch
1112
import torch.nn.utils.parametrize as parametrize
1213

@@ -161,6 +162,33 @@ def wrapper(*args, **kwargs):
161162
return decorator
162163

163164

165+
def skip_if_rocm(message=None):
166+
"""Decorator to skip tests on ROCm platform with custom message.
167+
168+
Args:
169+
message (str, optional): Additional information about why the test is skipped.
170+
"""
171+
172+
def decorator(func):
173+
@functools.wraps(func)
174+
def wrapper(*args, **kwargs):
175+
if torch.version.hip is not None:
176+
skip_message = "Skipping the test in ROCm"
177+
if message:
178+
skip_message += f": {message}"
179+
pytest.skip(skip_message)
180+
return func(*args, **kwargs)
181+
182+
return wrapper
183+
184+
# Handle both @skip_if_rocm and @skip_if_rocm() syntax
185+
if callable(message):
186+
func = message
187+
message = None
188+
return decorator(func)
189+
return decorator
190+
191+
164192
def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor:
165193
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
166194
torch.abs(output_ref)

0 commit comments

Comments
 (0)