@@ -33,6 +33,7 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
     float8_weight_only,
+    int4_dynamic_activation_int4_weight,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
@@ -50,6 +51,7 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
     is_sm_at_least_89,
+    is_sm_at_least_90,
     unwrap_tensor_subclass,
 )
 
@@ -798,6 +800,10 @@ def test_int4wo_cpu(self, dtype, x_dim):
             float8_weight_only(),
             float8_dynamic_activation_float8_weight(),
             float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
+            int4_dynamic_activation_int4_weight(),
+            int8_dynamic_activation_int8_weight(),
+            int8_dynamic_activation_int4_weight(),
+            int8_weight_only(),
         ],
     )
     def test_workflow_e2e_numerics(self, config):
@@ -816,6 +822,11 @@ def test_workflow_e2e_numerics(self, config):
             and not is_sm_at_least_89()
         ):
             return unittest.skip("requires CUDA capability 8.9 or greater")
+        elif (
+            isinstance(config, int4_dynamic_activation_int4_weight)
+            and is_sm_at_least_90()
+        ):
+            return unittest.skip("only supported on CUDA capability 8.9, not greater")
 
         # scale has to be moved to cuda here because the parametrization init
         # code happens before gating for cuda availability
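
For readers outside the torchao tree: is_sm_at_least_89 / is_sm_at_least_90 gate tests on GPU compute capability (SM 8.9 is Ada Lovelace, SM 9.0 is Hopper). A minimal sketch of such a helper using only core PyTorch; the name sm_at_least is hypothetical, not torchao's actual utility:

import torch

def sm_at_least(major: int, minor: int) -> bool:
    # True if the current CUDA device's compute capability is >= (major, minor).
    if not torch.cuda.is_available():
        return False
    dev_major, dev_minor = torch.cuda.get_device_capability()
    return (dev_major, dev_minor) >= (major, minor)

# Mirrors the gate in the hunk above: int4_dynamic_activation_int4_weight is
# skipped on SM 9.0+ because it is only supported on capability 8.9.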
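The newly added configs flow through the same parametrized end-to-end numerics check as the float8 ones. A rough sketch of what that check exercises, using torchao's quantize_ API; the toy model, input, and SQNR threshold below are illustrative assumptions, not the test's actual values:

import torch
from torchao.quantization import int8_weight_only, quantize_

# Toy model and input (illustrative; the real test uses its own module/dtype).
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).eval()
x = torch.randn(1, 128)

y_ref = model(x)                      # float reference output
quantize_(model, int8_weight_only())  # swap weights for int8 tensor subclasses
y_q = model(x)

# End-to-end numerics: the quantized output should track the float reference.
sqnr = 10 * torch.log10(y_ref.square().mean() / (y_ref - y_q).square().mean())
assert sqnr > 20  # loose, assumed threshold in dB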