93
93
except ModuleNotFoundError :
94
94
has_gemlite = False
95
95
96
+ from test_utils import skip_if_rocm
97
+
96
98
logger = logging .getLogger ("INFO" )
97
99
98
100
torch .manual_seed (0 )
@@ -569,6 +571,7 @@ def test_per_token_linear_cpu(self):
569
571
self ._test_per_token_linear_impl ("cpu" , dtype )
570
572
571
573
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
574
+ @skip_if_rocm ("ROCm development in progress" )
572
575
def test_per_token_linear_cuda (self ):
573
576
for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
574
577
self ._test_per_token_linear_impl ("cuda" , dtype )
@@ -687,6 +690,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
687
690
@parameterized .expand (COMMON_DEVICE_DTYPE )
688
691
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
689
692
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
693
+ @skip_if_rocm ("ROCm development in progress" )
690
694
def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
691
695
if device == "cpu" :
692
696
self .skipTest (f"Temporarily skipping for { device } " )
@@ -706,6 +710,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
706
710
@parameterized .expand (COMMON_DEVICE_DTYPE )
707
711
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
708
712
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
713
+ @skip_if_rocm ("ROCm development in progress" )
709
714
def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
710
715
if device == "cpu" :
711
716
self .skipTest (f"Temporarily skipping for { device } " )
@@ -899,6 +904,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
899
904
@parameterized .expand (COMMON_DEVICE_DTYPE )
900
905
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
901
906
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
907
+ @skip_if_rocm ("ROCm development in progress" )
902
908
def test_int4_weight_only_quant_subclass (self , device , dtype ):
903
909
if device == "cpu" :
904
910
self .skipTest (f"Temporarily skipping for { device } " )
@@ -918,6 +924,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
918
924
@parameterized .expand (COMMON_DEVICE_DTYPE )
919
925
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
920
926
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
927
+ @skip_if_rocm ("ROCm development in progress" )
921
928
def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
922
929
if dtype != torch .bfloat16 :
923
930
self .skipTest (f"Fails for { dtype } " )
0 commit comments