Commit f22207b

skip failing MX tests on cuda 10.0
Summary:

PyTorch's Triton version does not yet work on CUDA capability 10.0 (sm_100), so skip the relevant tests in the MX formats folder for now.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 80c4e04efaa75d94df6d17a3d8b5ff45788c0179
ghstack-comment-id: 2614583779
Pull Request resolved: #1624
1 parent 47f96f1 commit f22207b
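
The pattern applied throughout the diff below is a collection-time skip keyed on the device's compute capability. A minimal standalone sketch of that pattern follows; the helper and test names here are illustrative only, while the real helper added by this commit is torchao.utils.is_sm_at_least_100.

```python
import pytest
import torch


def _sm_at_least_100() -> bool:
    # Illustrative stand-in for torchao.utils.is_sm_at_least_100: True only
    # when a CUDA device with compute capability >= (10, 0) is present.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    _sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_example_kernel():
    # The skipif conditions are evaluated once at collection time, so on
    # sm_100 devices this test is reported as skipped rather than failing.
    assert torch.ones(4, device="cuda").sum().item() == 4.0
```

Stacked skipif markers behave like an OR: the test is skipped if any condition is true, which is why the new marker can simply be appended to the existing CUDA / triton / version guards in the tests below.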

File tree: 4 files changed (+37, -3 lines)

test/prototype/mx_formats/test_custom_cast.py (+7, -1)

```diff
@@ -40,7 +40,7 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

 torch.manual_seed(0)

@@ -310,6 +310,9 @@ def test_fp4_pack_unpack():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 def test_fp4_triton_unscaled_cast():
     packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
     f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -320,6 +323,9 @@ def test_fp4_triton_unscaled_cast():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
```

test/prototype/mx_formats/test_mx_linear.py (+11, -1)

```diff
@@ -18,7 +18,11 @@
     swap_linear_with_mx_linear,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)

 torch.manual_seed(2)

@@ -99,6 +103,9 @@ def test_activation_checkpointing():


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
@@ -184,6 +191,9 @@ def test_inference_linear(elem_dtype, bias, input_shape):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_inference_compile_simple(elem_dtype):
     """
```

test/prototype/mx_formats/test_mx_tensor.py (+10, -1)

```diff
@@ -21,7 +21,11 @@
     to_dtype,
 )
 from torchao.quantization.utils import compute_error
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_89
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)

 torch.manual_seed(2)

@@ -166,6 +170,8 @@ def test_transpose(elem_dtype, fp4_triton):
     """
     if elem_dtype != DTYPE_FP4 and fp4_triton:
         pytest.skip("unsupported configuration")
+    elif fp4_triton and is_sm_at_least_100():
+        pytest.skip("triton does not work yet on CUDA capability 10.0")

     M, K = 128, 256
     block_size = 32
@@ -205,6 +211,9 @@ def test_view(elem_dtype):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
```

torchao/utils.py (+9)

```diff
@@ -630,6 +630,15 @@ def is_sm_at_least_90():
     )


+# TODO(future PR): rename to 8_9, 9_0, 10_0 instead of 89, 90, 100
+def is_sm_at_least_100():
+    return (
+        torch.cuda.is_available()
+        and torch.version.cuda
+        and torch.cuda.get_device_capability() >= (10, 0)
+    )
+
+
 TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")
 TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
 TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
```
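
The new helper mirrors the existing is_sm_at_least_89 / is_sm_at_least_90 checks: torch.cuda.get_device_capability() returns a (major, minor) tuple, and Python compares tuples lexicographically, so (9, 0) >= (10, 0) is False while (10, 0) and above satisfy the check. A quick usage sketch (output values are illustrative):

```python
import torch

from torchao.utils import is_sm_at_least_100

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # e.g. sm_90 -> (9, 0), which compares as < (10, 0); sm_100 -> (10, 0).
    print(f"compute capability: ({major}, {minor})")

# bool() because the helper short-circuits to a falsy non-bool value
# (None or False) on non-CUDA builds.
print("is_sm_at_least_100:", bool(is_sm_at_least_100()))
```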

Comments (0)