
Commit a1c67b9

Skip Unit Tests for ROCm CI (#1563)
* skip failing unit tests for ROCm CI
* fix util import
Parent commit: d96c6a7

16 files changed (+71, -1)

test/__init__.py

Whitespace-only changes.

test/dtypes/test_affine_quantized.py (+4)

@@ -2,6 +2,7 @@
 import unittest
 
 import torch
+from test_utils import skip_if_rocm
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -89,6 +90,7 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
 
+    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
         "apply_quant", get_quantization_functions(True, True, "cuda", True)
@@ -168,6 +170,7 @@ def apply_uint6_weight_only_quant(linear):
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
+    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
@@ -180,6 +183,7 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
+    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_flatten_unflatten(self, device, dtype):

test/dtypes/test_floatx.py (+2)

@@ -2,6 +2,7 @@
 import unittest
 
 import torch
+from test_utils import skip_if_rocm
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
@@ -108,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
+    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64

test/float8/test_base.py (+3)

@@ -24,6 +24,8 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+from test_utils import skip_if_rocm
+
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
@@ -423,6 +425,7 @@ def test_linear_from_config_params(
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skip_if_rocm("ROCm development in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,

test/hqq/test_hqq_affine.py (+2)

@@ -1,6 +1,7 @@
 import unittest
 
 import torch
+from test_utils import skip_if_rocm
 
 from torchao.quantization import (
     MappingType,
@@ -110,6 +111,7 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )
 
+    @skip_if_rocm("ROCm development in progress")
    def test_hqq_plain_4bit(self):
        self._test_hqq(
            dtype=torch.uint4,

test/integration/test_integration.py (+7)

@@ -93,6 +93,8 @@
 except ModuleNotFoundError:
     has_gemlite = False
 
+from test_utils import skip_if_rocm
+
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
@@ -569,6 +571,7 @@ def test_per_token_linear_cpu(self):
         self._test_per_token_linear_impl("cpu", dtype)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_per_token_linear_cuda(self):
         for dtype in (torch.float32, torch.float16, torch.bfloat16):
             self._test_per_token_linear_impl("cuda", dtype)
@@ -687,6 +690,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -706,6 +710,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -899,6 +904,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -918,6 +924,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")

test/kernel/test_galore_downproj.py (+2)

@@ -8,6 +8,7 @@
 
 import torch
 from galore_test_utils import make_data
+from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
 from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
@@ -29,6 +30,7 @@
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
+@skip_if_rocm("ROCm development in progress")
 def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
     MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32

test/prototype/test_awq.py (+3)

@@ -10,6 +10,8 @@
 if TORCH_VERSION_AT_LEAST_2_3:
     from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
 
+from test_utils import skip_if_rocm
+
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -113,6 +115,7 @@ def test_awq_loading(device, qdtype):
 
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_rocm("ROCm development in progress")
 def test_save_weights_only():
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128

test/prototype/test_low_bit_optim.py (+2)

@@ -42,6 +42,7 @@
 except ImportError:
     lpmm = None
 
+from test_utils import skip_if_rocm
 
 _DEVICES = get_available_devices()
 
@@ -112,6 +113,7 @@ class TestOptim(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
+    @skip_if_rocm("ROCm development in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
             if not TORCH_VERSION_AT_LEAST_2_4:

test/prototype/test_splitk.py (+3)

@@ -13,13 +13,16 @@
 except ImportError:
     triton_available = False
 
+from test_utils import skip_if_rocm
+
 from torchao.utils import skip_if_compute_capability_less_than
 
 
 @unittest.skipIf(not triton_available, "Triton is required but not available")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
 class TestFP8Gemm(TestCase):
     @skip_if_compute_capability_less_than(9.0)
+    @skip_if_rocm("ROCm development in progress")
     def test_gemm_split_k(self):
         dtype = torch.float16
         qdtype = torch.float8_e4m3fn

test/quantization/test_galore_quant.py (+2)

@@ -13,6 +13,7 @@
     dequantize_blockwise,
     quantize_blockwise,
 )
+from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels import (
     triton_dequant_blockwise,
@@ -82,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
 )
+@skip_if_rocm("ROCm development in progress")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01
 

test/quantization/test_marlin_qqq.py (+3)

@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -45,6 +46,7 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq(self):
         output_ref = self.model(self.input)
         for group_size in [-1, 128]:
@@ -66,6 +68,7 @@ def test_marlin_qqq(self):
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq_compile(self):
         model_copy = copy.deepcopy(self.model)
         model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)

test/sparsity/test_marlin.py (+3, -1)

@@ -2,6 +2,7 @@
 
 import pytest
 import torch
+from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -37,6 +38,7 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_eager(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
@@ -48,13 +50,13 @@ def test_quant_sparse_marlin_layout_eager(self):
         # Sparse + quantized
         quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
-
         assert torch.allclose(
             dense_result, sparse_result, atol=3e-1
         ), "Results are not close"
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_compile(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)

test/test_ops.py (+3)

@@ -19,6 +19,9 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode
 
+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+
 if is_fbcode():
     pytest.skip(
         "Skipping the test in fbcode since we don't have TARGET file for kernels"

test/test_s8s4_linear_cutlass.py (+3)

@@ -7,6 +7,9 @@
 from torchao.quantization.utils import group_quantize_tensor_symmetric
 from torchao.utils import compute_max_diff
 
+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+
 S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
 S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
 S8S4_LINEAR_CUTLASS_SIZE_MNK = [

test/test_utils.py (+29)

@@ -1,11 +1,40 @@
+import functools
 import unittest
 from unittest.mock import patch
 
+import pytest
 import torch
 
 from torchao.utils import TorchAOBaseTensor, torch_version_at_least
 
 
+def skip_if_rocm(message=None):
+    """Decorator to skip tests on ROCm platform with custom message.
+
+    Args:
+        message (str, optional): Additional information about why the test is skipped.
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if torch.version.hip is not None:
+                skip_message = "Skipping the test in ROCm"
+                if message:
+                    skip_message += f": {message}"
+                pytest.skip(skip_message)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    # Handle both @skip_if_rocm and @skip_if_rocm() syntax
+    if callable(message):
+        func = message
+        message = None
+        return decorator(func)
+    return decorator
+
+
 class TestTorchVersionAtLeast(unittest.TestCase):
     def test_torch_version_at_least(self):
         test_cases = [
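For reference, a minimal usage sketch of the new helper (the test names below are hypothetical; the decorated tests in the files above are the real call sites). Because skip_if_rocm detects when it is handed a function directly, it supports both the bare and the parameterized decorator forms:

from test_utils import skip_if_rocm

@skip_if_rocm  # bare form: the function itself arrives as `message` and is decorated
def test_bare():
    assert True

@skip_if_rocm("ROCm development in progress")  # parameterized form: returns the decorator
def test_with_message():
    assert True

Note that, unlike pytest.mark.skipif, the pytest.skip call inside the wrapper fires when the test body is invoked rather than at collection time, which keeps the helper usable on plain functions and unittest methods alike.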
