
Commit 9afaabb

Revert "Skip Unit Tests for ROCm CI" (#1580)
Revert "Skip Unit Tests for ROCm CI (#1563)" This reverts commit a1c67b9.
1 parent: 69f3795

16 files changed (+1, -71 lines)

Diff for: test/__init__.py

Whitespace-only changes.

Diff for: test/dtypes/test_affine_quantized.py (-4 lines)

@@ -2,7 +2,6 @@
 import unittest
 
 import torch
-from test_utils import skip_if_rocm
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -90,7 +89,6 @@ def test_tensor_core_layout_transpose(self):
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
 
-    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
         "apply_quant", get_quantization_functions(True, True, "cuda", True)
@@ -170,7 +168,6 @@ def apply_uint6_weight_only_quant(linear):
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
-    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
@@ -183,7 +180,6 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @skip_if_rocm("ROCm development in progress")
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_flatten_unflatten(self, device, dtype):

Diff for: test/dtypes/test_floatx.py (-2 lines)

@@ -2,7 +2,6 @@
 import unittest
 
 import torch
-from test_utils import skip_if_rocm
 from torch.testing._internal.common_utils import (
     TestCase,
     instantiate_parametrized_tests,
@@ -109,7 +108,6 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
-    @skip_if_rocm("ROCm development in progress")
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64

Diff for: test/float8/test_base.py (-3 lines)

@@ -24,8 +24,6 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-from test_utils import skip_if_rocm
-
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
@@ -425,7 +423,6 @@ def test_linear_from_config_params(
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @skip_if_rocm("ROCm development in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,

Diff for: test/hqq/test_hqq_affine.py (-2 lines)

@@ -1,7 +1,6 @@
 import unittest
 
 import torch
-from test_utils import skip_if_rocm
 
 from torchao.quantization import (
     MappingType,
@@ -111,7 +110,6 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )
 
-    @skip_if_rocm("ROCm development in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,

Diff for: test/integration/test_integration.py (-7 lines)

@@ -93,8 +93,6 @@
 except ModuleNotFoundError:
     has_gemlite = False
 
-from test_utils import skip_if_rocm
-
 logger = logging.getLogger("INFO")
 
 torch.manual_seed(0)
@@ -571,7 +569,6 @@ def test_per_token_linear_cpu(self):
             self._test_per_token_linear_impl("cpu", dtype)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_per_token_linear_cuda(self):
         for dtype in (torch.float32, torch.float16, torch.bfloat16):
             self._test_per_token_linear_impl("cuda", dtype)
@@ -690,7 +687,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -710,7 +706,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -904,7 +899,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -924,7 +918,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
-    @skip_if_rocm("ROCm development in progress")
     def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")

Diff for: test/kernel/test_galore_downproj.py (-2 lines)

@@ -8,7 +8,6 @@
 
 import torch
 from galore_test_utils import make_data
-from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels.matmul import set_tuner_top_k as matmul_tuner_topk
 from torchao.prototype.galore.kernels.matmul import triton_mm_launcher
@@ -30,7 +29,6 @@
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
-@skip_if_rocm("ROCm development in progress")
 def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
     MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32

Diff for: test/prototype/test_awq.py (-3 lines)

@@ -10,8 +10,6 @@
 if TORCH_VERSION_AT_LEAST_2_3:
     from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
 
-from test_utils import skip_if_rocm
-
 
 class ToyLinearModel(torch.nn.Module):
     def __init__(self, m=512, n=256, k=128):
@@ -115,7 +113,6 @@ def test_awq_loading(device, qdtype):
 
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@skip_if_rocm("ROCm development in progress")
 def test_save_weights_only():
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128

Diff for: test/prototype/test_low_bit_optim.py (-2 lines)

@@ -42,7 +42,6 @@
 except ImportError:
     lpmm = None
 
-from test_utils import skip_if_rocm
 
 _DEVICES = get_available_devices()
 
@@ -113,7 +112,6 @@ class TestOptim(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
-    @skip_if_rocm("ROCm development in progress")
    def test_optim_smoke(self, optim_name, dtype, device):
        if optim_name.endswith("Fp8") and device == "cuda":
            if not TORCH_VERSION_AT_LEAST_2_4:

Diff for: test/prototype/test_splitk.py (-3 lines)

@@ -13,16 +13,13 @@
 except ImportError:
     triton_available = False
 
-from test_utils import skip_if_rocm
-
 from torchao.utils import skip_if_compute_capability_less_than
 
 
 @unittest.skipIf(not triton_available, "Triton is required but not available")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
 class TestFP8Gemm(TestCase):
     @skip_if_compute_capability_less_than(9.0)
-    @skip_if_rocm("ROCm development in progress")
     def test_gemm_split_k(self):
         dtype = torch.float16
         qdtype = torch.float8_e4m3fn

Diff for: test/quantization/test_galore_quant.py (-2 lines)

@@ -13,7 +13,6 @@
     dequantize_blockwise,
     quantize_blockwise,
 )
-from test_utils import skip_if_rocm
 
 from torchao.prototype.galore.kernels import (
     triton_dequant_blockwise,
@@ -83,7 +82,6 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
 )
-@skip_if_rocm("ROCm development in progress")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

Diff for: test/quantization/test_marlin_qqq.py (-3 lines)

@@ -3,7 +3,6 @@
 
 import pytest
 import torch
-from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -46,7 +45,6 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq(self):
         output_ref = self.model(self.input)
         for group_size in [-1, 128]:
@@ -68,7 +66,6 @@ def test_marlin_qqq(self):
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_marlin_qqq_compile(self):
         model_copy = copy.deepcopy(self.model)
         model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)

Diff for: test/sparsity/test_marlin.py (+1, -3 lines)

@@ -2,7 +2,6 @@
 
 import pytest
 import torch
-from test_utils import skip_if_rocm
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 
@@ -38,7 +37,6 @@ def setUp(self):
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_eager(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
@@ -50,13 +48,13 @@ def test_quant_sparse_marlin_layout_eager(self):
         # Sparse + quantized
         quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)
+
         assert torch.allclose(
             dense_result, sparse_result, atol=3e-1
         ), "Results are not close"
 
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    @skip_if_rocm("ROCm development in progress")
     def test_quant_sparse_marlin_layout_compile(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)

Diff for: test/test_ops.py (-3 lines)

@@ -19,9 +19,6 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode
 
-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
-
 if is_fbcode():
     pytest.skip(
         "Skipping the test in fbcode since we don't have TARGET file for kernels"

Diff for: test/test_s8s4_linear_cutlass.py (-3 lines)

@@ -7,9 +7,6 @@
 from torchao.quantization.utils import group_quantize_tensor_symmetric
 from torchao.utils import compute_max_diff
 
-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
-
 S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
 S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
 S8S4_LINEAR_CUTLASS_SIZE_MNK = [

Diff for: test/test_utils.py (-29 lines)

@@ -1,40 +1,11 @@
-import functools
 import unittest
 from unittest.mock import patch
 
-import pytest
 import torch
 
 from torchao.utils import TorchAOBaseTensor, torch_version_at_least
 
 
-def skip_if_rocm(message=None):
-    """Decorator to skip tests on ROCm platform with custom message.
-
-    Args:
-        message (str, optional): Additional information about why the test is skipped.
-    """
-
-    def decorator(func):
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            if torch.version.hip is not None:
-                skip_message = "Skipping the test in ROCm"
-                if message:
-                    skip_message += f": {message}"
-                pytest.skip(skip_message)
-            return func(*args, **kwargs)
-
-        return wrapper
-
-    # Handle both @skip_if_rocm and @skip_if_rocm() syntax
-    if callable(message):
-        func = message
-        message = None
-        return decorator(func)
-    return decorator
-
-
 class TestTorchVersionAtLeast(unittest.TestCase):
     def test_torch_version_at_least(self):
         test_cases = [