@@ -903,6 +903,7 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
903
903
@parameterized .expand (COMMON_DEVICE_DTYPE )
904
904
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
905
905
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
906
+ @skip_if_rocm ("ROCm enablement in progress" )
906
907
def test_int4_weight_only_quant_subclass (self , device , dtype ):
907
908
if device == "cpu" :
908
909
self .skipTest (f"Temporarily skipping for { device } " )
@@ -922,6 +923,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
922
923
@parameterized .expand (COMMON_DEVICE_DTYPE )
923
924
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
924
925
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
926
+ @skip_if_rocm ("ROCm enablement in progress" )
925
927
def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
926
928
if dtype != torch .bfloat16 :
927
929
self .skipTest (f"Fails for { dtype } " )
@@ -1075,6 +1077,7 @@ def test_gemlite_layout(self, device, dtype):
1075
1077
@parameterized .expand (COMMON_DEVICE_DTYPE )
1076
1078
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
1077
1079
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
1080
+ @skip_if_rocm ("ROCm enablement in progress" )
1078
1081
def test_int4_weight_only_quant_subclass_api_grouped (self , device , dtype ):
1079
1082
if device == "cpu" :
1080
1083
self .skipTest (f"Temporarily skipping for { device } " )
0 commit comments