@@ -129,9 +129,16 @@ def test_quant_llm_linear_correctness(
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
 
 
+def make_test_id(param):
+    if isinstance(param, tuple) and len(param) == 2:  # This is a shape
+        return f"shape_{param[0]}x{param[1]}"
+    else:  # This is inner_k_tiles
+        return f"tiles_{param}"
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
-@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
+@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
     assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0
@@ -149,7 +156,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
-@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
+@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
         "test_schema",