2525from torchao .sparsity .marlin import inject_24 , marlin_24_workspace , pack_to_marlin_24
2626from torchao .utils import TORCH_VERSION_AT_LEAST_2_5 , compute_max_diff
2727
28- if torch .version . hip is not None :
29- pytest . skip ( "Skipping the test in ROCm" , allow_module_level = True )
28+ IS_CUDA = torch .cuda . is_available () and torch . version . cuda
29+ IS_ROCM = torch . cuda . is_available () and torch . version . hip
3030
3131try :
3232 import torchao .ops
@@ -52,7 +52,7 @@ def _create_floatx_inputs(
5252 fp16_act = torch .rand (BS , IC ).to (dtype ) + 0.5
5353 return floatx_weight .to (device ), scale .to (device ), fp16_act .to (device )
5454
55- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
55+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
5656 @parametrize ("ebits,mbits" , [(3 , 2 ), (2 , 2 )])
5757 @parametrize ("dtype" , [torch .half , torch .bfloat16 ])
5858 def test_quant_llm_linear (self , ebits , mbits , dtype ):
@@ -82,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
8282 test_utils = test_utils ,
8383 )
8484
85- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
85+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
8686 @parametrize ("BS,OC,IC,splitK" , [(1 , 2048 , 4096 , 5 ), (2 , 8192 , 8192 , 6 )])
8787 @parametrize ("ebits,mbits" , [(3 , 2 ), (2 , 2 )])
8888 @parametrize ("dtype" , [torch .half , torch .bfloat16 ])
@@ -139,7 +139,7 @@ def make_test_id(param):
139139 return f"tiles_{ param } "
140140
141141
142- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
142+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
143143# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
144144@pytest .mark .parametrize ("shape, inner_k_tiles" , TEST_CONFIGS_UNPACK , ids = make_test_id )
145145def test_unpack_tensor_core_tiled_layout_correctness (shape , inner_k_tiles ):
@@ -157,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
157157
158158
159159# TODO: Fix "test_aot_dispatch_dynamic" test failure
160- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
160+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
161161# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
162162@pytest .mark .parametrize ("shape, inner_k_tiles" , TEST_CONFIGS_UNPACK , ids = make_test_id )
163163def test_unpack_tensor_core_tiled_layout_op (shape , inner_k_tiles ):
@@ -203,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
203203 return dq .reshape (n , k )
204204
205205
206- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
206+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
207207# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
208208@pytest .mark .parametrize (
209209 "shape, inner_k_tiles, group_size" , TEST_CONFIGS_DEQUANT , ids = str
@@ -271,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(
271271
272272
273273# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
274- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
274+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
275275# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
276276@pytest .mark .parametrize (
277277 "shape, inner_k_tiles, group_size" , TEST_CONFIGS_DEQUANT , ids = str
@@ -337,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
337337 assert diff_op_ao < 1e-1
338338
339339
340- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
340+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
341341# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
342342@pytest .mark .parametrize (
343343 "shape, inner_k_tiles, group_size" , TEST_CONFIGS_DEQUANT , ids = str
@@ -448,7 +448,7 @@ def reshape_w(w):
448448 )
449449
450450
451- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
451+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
452452@pytest .mark .parametrize (
453453 "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors" ,
454454 MARLIN_TEST_PARAMS ,
@@ -538,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
538538)
539539
540540
541- @pytest .mark .skipif (not torch . cuda . is_available () , reason = "CUDA not available" )
541+ @pytest .mark .skipif (not IS_CUDA , reason = "CUDA not available" )
542542@pytest .mark .parametrize (
543543 "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors" ,
544544 MARLIN_TEST_PARAMS ,
@@ -617,5 +617,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
617617 )
618618
619619
620+ @pytest .mark .skipif (not IS_ROCM , reason = "ROCm not available" )
621+ def test_swizzle_mm ():
622+ test_utils = [
623+ "test_schema" ,
624+ "test_autograd_registration" ,
625+ "test_faketensor" ,
626+ ]
627+
628+ # TODO: Figure out why test fails unless torch >= 2.5
629+ if TORCH_VERSION_AT_LEAST_2_5 :
630+ test_utils .append ("test_aot_dispatch_dynamic" )
631+
632+ mat1 = torch .randint (0 , 16 , dtype = torch .float , size = (16 , 32 ), device = "cuda" )
633+ mat2 = torch .randint (0 , 16 , dtype = torch .float , size = (32 , 16 ), device = "cuda" )
634+
635+ opcheck (
636+ torch .ops .torchao .swizzle_mm ,
637+ (mat1 , mat2 , False , False ),
638+ test_utils = test_utils ,
639+ )
640+
641+
620642if __name__ == "__main__" :
621643 pytest .main (sys .argv )
0 commit comments