@@ -570,7 +570,6 @@ def test_per_token_linear_cpu(self):
570
570
self ._test_per_token_linear_impl ("cpu" , dtype )
571
571
572
572
@unittest .skipIf (not torch .cuda .is_available (), "Need CUDA available" )
573
- @skip_if_rocm ("ROCm development in progress" )
574
573
def test_per_token_linear_cuda (self ):
575
574
for dtype in (torch .float32 , torch .float16 , torch .bfloat16 ):
576
575
self ._test_per_token_linear_impl ("cuda" , dtype )
@@ -689,7 +688,6 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
689
688
@parameterized .expand (COMMON_DEVICE_DTYPE )
690
689
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
691
690
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
692
- @skip_if_rocm ("ROCm development in progress" )
693
691
def test_dequantize_int4_weight_only_quant_subclass (self , device , dtype ):
694
692
if device == "cpu" :
695
693
self .skipTest (f"Temporarily skipping for { device } " )
@@ -709,7 +707,6 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
709
707
@parameterized .expand (COMMON_DEVICE_DTYPE )
710
708
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
711
709
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
712
- @skip_if_rocm ("ROCm development in progress" )
713
710
def test_dequantize_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
714
711
if device == "cpu" :
715
712
self .skipTest (f"Temporarily skipping for { device } " )
@@ -903,7 +900,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
903
900
@parameterized .expand (COMMON_DEVICE_DTYPE )
904
901
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
905
902
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
906
- @skip_if_rocm ("ROCm development in progress" )
907
903
def test_int4_weight_only_quant_subclass (self , device , dtype ):
908
904
if device == "cpu" :
909
905
self .skipTest (f"Temporarily skipping for { device } " )
@@ -923,7 +919,6 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
923
919
@parameterized .expand (COMMON_DEVICE_DTYPE )
924
920
@unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_3 , "int4 requires torch nightly." )
925
921
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
926
- @skip_if_rocm ("ROCm development in progress" )
927
922
def test_int4_weight_only_quant_subclass_grouped (self , device , dtype ):
928
923
if dtype != torch .bfloat16 :
929
924
self .skipTest (f"Fails for { dtype } " )
@@ -1827,7 +1822,7 @@ def test_autoquant_int4wo(self, device, dtype):
1827
1822
self .assertGreater (compute_error (ref , out ), 20 )
1828
1823
1829
1824
@parameterized .expand (COMMON_DEVICE_DTYPE )
1830
- @unittest .skipIf (not torch . cuda . is_available (), "Need CUDA available " )
1825
+ @unittest .skipIf (not is_sm_at_least_90 (), "Need cuda arch greater than SM90 " )
1831
1826
@unittest .skipIf (
1832
1827
not TORCH_VERSION_AT_LEAST_2_5 , "autoquant int4 option requires 2.5+."
1833
1828
)
0 commit comments