Skip to content

Commit bb36700

Browse files
committed
skip fsdp2 test for ROCm
1 parent aec1039 commit bb36700

File tree

4 files changed

+10
-0
lines changed

4 files changed

+10
-0
lines changed

test/float8/test_fsdp2/test_fsdp2.py

+3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@
43 43   if not is_sm_at_least_89():
44 44       pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)
45 45
   46 + if torch.version.hip is not None:
   47 +     pytest.skip("ROCm enablement in progress", allow_module_level=True)
   48 +
46 49
47 50   class TestFloat8Common:
48 51       def broadcast_module(self, module: nn.Module) -> None:

test/integration/test_integration.py

+3
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
903 903       @parameterized.expand(COMMON_DEVICE_DTYPE)
904 904       @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
905 905       # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
    906 +     @skip_if_rocm("ROCm enablement in progress")
906 907       def test_int4_weight_only_quant_subclass(self, device, dtype):
907 908           if device == "cpu":
908 909               self.skipTest(f"Temporarily skipping for {device}")
@@ -922,6 +923,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
922 923       @parameterized.expand(COMMON_DEVICE_DTYPE)
923 924       @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
924 925       # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
    926 +     @skip_if_rocm("ROCm enablement in progress")
925 927       def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
926 928           if dtype != torch.bfloat16:
927 929               self.skipTest(f"Fails for {dtype}")
@@ -1075,6 +1077,7 @@ def test_gemlite_layout(self, device, dtype):
1075 1077       @parameterized.expand(COMMON_DEVICE_DTYPE)
1076 1078       @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
1077 1079       # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
     1080 +     @skip_if_rocm("ROCm enablement in progress")
1078 1081       def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
1079 1082           if device == "cpu":
1080 1083               self.skipTest(f"Temporarily skipping for {device}")

test/kernel/test_fused_kernels.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
11 11   import torch
12 12   from galore_test_utils import get_kernel, make_copy, make_data
13 13
   14 + from torchao.utils import skip_if_rocm
   15 +
14 16   torch.manual_seed(0)
15 17   MAX_DIFF_no_tf32 = 1e-5
16 18   MAX_DIFF_tf32 = 1e-3
@@ -104,6 +106,7 @@ def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32):
104 106
105 107   @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
106 108   @pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS)
    109 + @skip_if_rocm("ROCm enablement in progress")
107 110   def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32):
108 111       torch.backends.cuda.matmul.allow_tf32 = allow_tf32
109 112

test/prototype/test_low_bit_optim.py

+1
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ def world_size(self) -> int:
386 386           not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
387 387       )
388 388       @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
    389 +     @skip_if_rocm("ROCm enablement in progress")
389 390       def test_fsdp2(self):
390 391           optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
391 392           if torch.cuda.get_device_capability() >= (8, 9):

0 commit comments

Comments
 (0)