 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff

-IS_CUDA = torch.cuda.is_available() and torch.version.cuda
-IS_ROCM = torch.cuda.is_available() and torch.version.hip
+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)

 try:
     import torchao.ops
@@ -52,7 +52,7 @@ def _create_floatx_inputs(
         fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
         return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
     def test_quant_llm_linear(self, ebits, mbits, dtype):
@@ -82,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
             test_utils=test_utils,
         )

-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
@@ -139,7 +139,7 @@ def make_test_id(param):
     return f"tiles_{param}"


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
@@ -157,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):


 # TODO: Fix "test_aot_dispatch_dynamic" test failure
-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
@@ -203,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
     return dq.reshape(n, k)


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -271,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(


 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -337,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
     assert diff_op_ao < 1e-1


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -448,7 +448,7 @@ def reshape_w(w):
     )


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -538,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
     )


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -617,27 +617,5 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
     )


-@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
-def test_swizzle_mm():
-    test_utils = [
-        "test_schema",
-        "test_autograd_registration",
-        "test_faketensor",
-    ]
-
-    # TODO: Figure out why test fails unless torch >= 2.5
-    if TORCH_VERSION_AT_LEAST_2_5:
-        test_utils.append("test_aot_dispatch_dynamic")
-
-    mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda")
-    mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda")
-
-    opcheck(
-        torch.ops.torchao.swizzle_mm,
-        (mat1, mat2, False, False),
-        test_utils=test_utils,
-    )
-
-
 if __name__ == "__main__":
     pytest.main(sys.argv)
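For context, the diff replaces the `IS_CUDA`/`IS_ROCM` flags with two standard pytest idioms: a module-level skip when running on a ROCm build of PyTorch, and a per-test `skipif` on CUDA availability. Below is a minimal standalone sketch of those two idioms only; it is not part of the patch, and `test_dummy_op` is a hypothetical placeholder test.

```python
import pytest
import torch

# Skip the entire module on ROCm builds, mirroring the module-level guard
# the diff adds at the top of the test file.
if torch.version.hip is not None:
    pytest.skip("Skipping the test in ROCm", allow_module_level=True)


# Skip an individual test when no CUDA device is present, mirroring the
# per-test decorator the diff switches to. This test is a hypothetical
# placeholder, not one from the patch.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dummy_op():
    x = torch.ones(4, device="cuda")
    assert torch.equal(x + x, torch.full((4,), 2.0, device="cuda"))
```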