     is_sm_at_least_89,
 )
 
+is_cusparselt_available = (
+    hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
+)
+
 
 def get_quantization_functions(
     do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
@@ -91,7 +95,8 @@ def test_tensor_core_layout_transpose(self):
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize(
-        "apply_quant", get_quantization_functions(True, True, "cuda", True)
+        "apply_quant",
+        get_quantization_functions(is_cusparselt_available, True, "cuda", True),
     )
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
@@ -168,7 +173,9 @@ def apply_uint6_weight_only_quant(linear):
 
         deregister_aqt_quantized_linear_dispatch(dispatch_condition)
 
-    @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
+    @common_utils.parametrize(
+        "apply_quant", get_quantization_functions(is_cusparselt_available, True)
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
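For context, a minimal self-contained sketch of the pattern this change relies on: `common_utils.parametrize` builds its argument list at test-collection time, so the cuSPARSELt availability probe has to run at module import rather than inside a test body. The body of `get_quantization_functions` below is a placeholder stand-in, not the real torchao helper, and the test itself is purely illustrative:

```python
import unittest

import torch
from torch.testing._internal import common_utils

# Older PyTorch builds do not expose torch.backends.cusparselt at all,
# hence the hasattr() guard before calling is_available().
is_cusparselt_available = (
    hasattr(torch.backends, "cusparselt") and torch.backends.cusparselt.is_available()
)


def _identity_quant(linear):
    # Placeholder for a dense quantization function.
    return linear


def _identity_sparse_quant(linear):
    # Placeholder for a sparse (cuSPARSELt-backed) quantization function.
    return linear


def get_quantization_functions(do_sparse: bool, do_int4: bool):
    # Placeholder stand-in: the real helper builds int8/int4 (and, when
    # do_sparse is True, sparse) quantization configs.
    fns = [_identity_quant]
    if do_sparse:
        fns.append(_identity_sparse_quant)
    return fns


class ExampleTest(common_utils.TestCase):
    # The parametrize list is evaluated at collection time, so gating it on
    # is_cusparselt_available simply omits the sparse cases on machines
    # without cuSPARSELt instead of failing at runtime.
    @common_utils.parametrize(
        "apply_quant", get_quantization_functions(is_cusparselt_available, True)
    )
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_example(self, apply_quant):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        apply_quant(linear)


common_utils.instantiate_parametrized_tests(ExampleTest)

if __name__ == "__main__":
    common_utils.run_tests()
```

Passing `is_cusparselt_available` as the `do_sparse` argument is the design choice the diff makes: the sparse test variants are dropped from the parametrization on unsupported machines rather than collected and then failed.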