
Commit cc3f9d9

skip ROCm tests

1 parent cd5b4e4 commit cc3f9d9

13 files changed: +20 −2 lines changed
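Every change below applies the same skip_if_rocm decorator, imported from torchao.utils. As a rough sketch of the mechanics (an illustration under assumptions, not the actual torchao implementation), a message-parameterized decorator that works on both test functions and TestCase classes could look like this, keying off torch.version.hip, which is non-None only on ROCm builds of PyTorch:

# Illustrative sketch only; the real torchao.utils.skip_if_rocm may differ.
import functools
import unittest

import torch


def skip_if_rocm(message=None):
    """Skip the decorated test (or every test in a decorated class) on ROCm."""
    reason = message or "test not supported on ROCm"

    def decorator(obj):
        on_rocm = torch.version.hip is not None  # non-None only on ROCm builds
        if isinstance(obj, type):
            # Class-level usage: skip every test method in the TestCase.
            return unittest.skipIf(on_rocm, reason)(obj)

        @functools.wraps(obj)
        def wrapper(*args, **kwargs):
            if on_rocm:
                raise unittest.SkipTest(reason)
            return obj(*args, **kwargs)

        return wrapper

    return decorator

Raising unittest.SkipTest inside the wrapper is recognized by both unittest and pytest as a skip rather than a failure, which matters here because the suite mixes both runners.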

test/dtypes/test_affine_quantized.py (+3)

@@ -94,6 +94,7 @@ def test_tensor_core_layout_transpose(self):
     @common_utils.parametrize(
         "apply_quant", get_quantization_functions(True, True, "cuda", True)
     )
+    @skip_if_rocm("ROCm enablement in progress")
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         ql = apply_quant(linear)
@@ -171,6 +172,7 @@ def apply_uint6_weight_only_quant(linear):

     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_print_quantized_module(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         ql = apply_quant(linear)
@@ -183,6 +185,7 @@ class TestAffineQuantizedBasic(TestCase):

     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:

test/dtypes/test_floatx.py (+1)

@@ -109,6 +109,7 @@ def test_to_copy_device(self, ebits, mbits):
     @parametrize("bias", [False, True])
     @parametrize("dtype", [torch.half, torch.bfloat16])
     @unittest.skipIf(is_fbcode(), reason="broken in fbcode")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
         device = "cuda"

test/dtypes/test_uint4.py (+3 −1)

@@ -28,7 +28,7 @@
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, skip_if_rocm


 def _apply_weight_only_uint4_quant(model):
@@ -92,6 +92,7 @@ def test_basic_tensor_ops(self):
         # only test locally
         # print("x:", x[0])

+    @skip_if_rocm("ROCm enablement in progress")
     def test_gpu_quant(self):
         for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
             x = torch.randn(*x_shape)
@@ -104,6 +105,7 @@ def test_gpu_quant(self):
         # make sure it runs
         opt(x)

+    @skip_if_rocm("ROCm enablement in progress")
     def test_pt2e_quant(self):
         from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
             QuantizationConfig,

test/float8/test_base.py (+1)

@@ -424,6 +424,7 @@ def test_linear_from_config_params(
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize("linear_bias", [True, False])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_linear_from_recipe(
         self,
         recipe_name,

test/hqq/test_hqq_affine.py (+1)

@@ -111,6 +111,7 @@ def test_hqq_plain_5bit(self):
             ref_dot_product_error=0.000704,
         )

+    @skip_if_rocm("ROCm enablement in progress")
     def test_hqq_plain_4bit(self):
         self._test_hqq(
             dtype=torch.uint4,

test/integration/test_integration.py (+3)

@@ -570,6 +570,7 @@ def test_per_token_linear_cpu(self):
         self._test_per_token_linear_impl("cpu", dtype)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_per_token_linear_cuda(self):
         for dtype in (torch.float32, torch.float16, torch.bfloat16):
             self._test_per_token_linear_impl("cuda", dtype)
@@ -688,6 +689,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")
@@ -707,6 +709,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
         if device == "cpu":
             self.skipTest(f"Temporarily skipping for {device}")

test/kernel/test_galore_downproj.py (+1)

@@ -30,6 +30,7 @@

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
 @pytest.mark.parametrize("M, N, rank, allow_tf32, fp8_fast_accum, dtype", TEST_CONFIGS)
+@skip_if_rocm("ROCm enablement in progress")
 def test_galore_downproj(M, N, rank, allow_tf32, fp8_fast_accum, dtype):
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
     MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32

test/prototype/test_awq.py (+1)

@@ -117,6 +117,7 @@ def test_awq_loading(device, qdtype):

 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="requires nightly pytorch")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@skip_if_rocm("ROCm enablement in progress")
 def test_save_weights_only():
     dataset_size = 100
     l1, l2, l3 = 512, 256, 128

test/prototype/test_low_bit_optim.py (+1)

@@ -113,6 +113,7 @@ class TestOptim(TestCase):
     )
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
             if not TORCH_VERSION_AT_LEAST_2_4:

test/prototype/test_splitk.py (+1 −1)

@@ -17,11 +17,11 @@
 from torchao.utils import skip_if_compute_capability_less_than, skip_if_rocm


-
 @unittest.skipIf(not triton_available, "Triton is required but not available")
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
 class TestFP8Gemm(TestCase):
     @skip_if_compute_capability_less_than(9.0)
+    @skip_if_rocm("ROCm enablement in progress")
     def test_gemm_split_k(self):
         dtype = torch.float16
         qdtype = torch.float8_e4m3fn

test/quantization/test_galore_quant.py (+1)

@@ -83,6 +83,7 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
     "dim1,dim2,dtype,signed,blocksize",
     TEST_CONFIGS,
 )
+@skip_if_rocm("ROCm enablement in progress")
 def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
     g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01


test/quantization/test_marlin_qqq.py (+1)

@@ -26,6 +26,7 @@
     is_fbcode(),
     "Skipping the test in fbcode since we don't have TARGET file for kernels",
 )
+@skip_if_rocm("ROCm enablement in progress")
 class TestMarlinQQQ(TestCase):
     def setUp(self):
         super().setUp()
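Note that in test_marlin_qqq.py the decorator sits above the class rather than an individual test, so presumably every test method in TestMarlinQQQ is skipped on ROCm, not just one.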

test/sparsity/test_marlin.py (+2)

@@ -37,6 +37,7 @@ def setUp(self):
         )

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_quant_sparse_marlin_layout_eager(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
@@ -55,6 +56,7 @@ def test_quant_sparse_marlin_layout_eager(self):

     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @skip_if_rocm("ROCm enablement in progress")
     def test_quant_sparse_marlin_layout_compile(self):
         apply_fake_sparsity(self.model)
         model_copy = copy.deepcopy(self.model)
