@@ -570,6 +570,7 @@ def test_per_token_linear_cpu(self):
570
570
self ._test_per_token_linear_impl ("cpu" , dtype )
571
571
572
572
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
573
+ @skip_if_rocm ("ROCm enablement in progress" )
573
574
def test_per_token_linear_cuda (self ):
574
575
for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
575
576
self ._test_per_token_linear_impl ("cuda" , dtype )
@@ -688,6 +689,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
688
689
@parameterized .expand (COMMON_DEVICE_DTYPE )
689
690
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
690
691
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
692
+ @skip_if_rocm ("ROCm enablement in progress" )
691
693
def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
692
694
if device == "cpu" :
693
695
self .skipTest (f"Temporarily skipping for { device } " )
@@ -707,6 +709,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
707
709
@parameterized .expand (COMMON_DEVICE_DTYPE )
708
710
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
709
711
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
712
+ @skip_if_rocm ("ROCm enablement in progress" )
710
713
def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
711
714
if device == "cpu" :
712
715
self .skipTest (f"Temporarily skipping for { device } " )
0 commit comments