
Commit 083018d

Fix contraction matching bug, add tests
Signed-off-by: Max Dawkins <[email protected]>
1 parent 2876528 commit 083018d

File tree

3 files changed: +82 -29 lines changed


tuner/tuner/dispatch_parser.py

Lines changed: 6 additions & 3 deletions
@@ -117,17 +117,20 @@ def get_shapes(self, template: list[str]) -> ProblemSize:
         assert False, f"contraction op not found"
         cdims = matcher.contraction_dimensions
         assert cdims, "no contraction dimensions"
+        assert len(cdims.batch) <= 1, f"must have at most 1 batch dimension"
         assert len(cdims.m) == 1, f"must have a single m dimension"
         assert len(cdims.n) == 1, f"must have a single n dimension"
         assert len(cdims.k) == 1, f"must have a single k dimension"
         lhs_type = ir.RankedTensorType(contraction_op.operands[0].type)
         rhs_type = ir.RankedTensorType(contraction_op.operands[1].type)
         res_type = ir.RankedTensorType(contraction_op.operands[2].type)
         matmul_size = MatmulSize(
-            lhs_type.shape[0],
-            rhs_type.shape[0],
-            lhs_type.shape[1],
+            lhs_type.shape[matcher.lhs_dims.index(cdims.m[0])],
+            rhs_type.shape[matcher.rhs_dims.index(cdims.n[0])],
+            lhs_type.shape[matcher.lhs_dims.index(cdims.k[0])],
         )
+        if len(cdims.batch) == 1:
+            matmul_size.B = lhs_type.shape[matcher.lhs_dims.index(cdims.batch[0])]
         return ProblemSize(
             matmul_size,
             lhs_type=ShapedType(lhs_type.shape, lhs_type.element_type),
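
Note: the change above replaces hard-coded shape indices with lookups through the matcher's per-operand dimension lists. A minimal standalone sketch of that lookup (plain Python with hypothetical values, not the tuner code itself), for a transposed-B matmul:

# C[m, n] += A[m, k] * B[n, k]; iteration dims d0=m, d1=n, d2=k.
lhs_dims = [0, 2]      # A indexed by (d0, d2)
rhs_dims = [1, 2]      # B indexed by (d1, d2)
lhs_shape = [16, 64]   # tensor<16x64xf16>
rhs_shape = [32, 64]   # tensor<32x64xf16>

m_dim, n_dim, k_dim = 0, 1, 2

# Find each problem size by locating the contraction dim inside the
# operand's indexing map instead of assuming a fixed shape position.
M = lhs_shape[lhs_dims.index(m_dim)]   # 16
N = rhs_shape[rhs_dims.index(n_dim)]   # 32
K = lhs_shape[lhs_dims.index(k_dim)]   # 64
assert (M, N, K) == (16, 32, 64)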

tuner/tuner/dispatch_parser_test.py

Lines changed: 70 additions & 26 deletions
@@ -41,38 +41,82 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
     )


+CONTRACTION_TEMPLATE = r"""
+builtin.module{{
+    func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
+        %cst = arith.constant 0.000000e+00 : f32
+        %0 = tensor.empty() : {res_type}
+        %1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
+        %2 = linalg.generic {{
+            indexing_maps = [
+                {lhs_map},
+                {rhs_map},
+                {res_map}],
+            iterator_types = {iterator_types}}}
+            {{root_op}}
+            ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
+            outs(%1 : {res_type}) {{
+        ^bb0(%in: f16, %in_0: f16, %out: f32):
+            %3 = arith.extf %in : f16 to f32
+            %4 = arith.extf %in_0 : f16 to f32
+            %5 = arith.mulf %3, %4 : f32
+            %6 = arith.addf %out, %5 : f32
+            linalg.yield %6 : f32
+        }} -> {res_type}
+        return %2 : {res_type}
+    }}
+}}"""
+
+
 def test_get_contraction_operation(tuner_ctx: common.TunerContext) -> None:
     context = tuner_ctx.mlir_ctx
-    module_str = """
-        builtin.module{
-            func.func @test(%arg0: tensor<4x4xf16>, %arg1: tensor<4x4xf16>) -> tensor<4x4xf32> {
-                %cst = arith.constant 0.000000e+00 : f32
-                %0 = tensor.empty() : tensor<4x4xf32>
-                %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
-                %2 = linalg.generic {
-                    indexing_maps = [
-                        affine_map<(d0, d1, d2) -> (d0, d2)>,
-                        affine_map<(d0, d1, d2) -> (d1, d2)>,
-                        affine_map<(d0, d1, d2) -> (d0, d1)>],
-                    iterator_types = ["parallel", "parallel", "reduction"]}
-                    {root_op}
-                    ins(%arg0, %arg1 : tensor<4x4xf16>, tensor<4x4xf16>)
-                    outs(%1 : tensor<4x4xf32>) {
-                ^bb0(%in: f16, %in_0: f16, %out: f32):
-                    %3 = arith.extf %in : f16 to f32
-                    %4 = arith.extf %in_0 : f16 to f32
-                    %5 = arith.mulf %3, %4 : f32
-                    %6 = arith.addf %out, %5 : f32
-                    linalg.yield %6 : f32
-                } -> tensor<4x4xf32>
-                return %2 : tensor<4x4xf32>
-            }
-        }"""
-    module = ir.Module.parse(module_str, context)
+
+    with ir.Location.unknown():
+        transpose_b_str = CONTRACTION_TEMPLATE.format(
+            lhs_type=ir.RankedTensorType.get([16, 64], ir.F16Type.get()),
+            rhs_type=ir.RankedTensorType.get([32, 64], ir.F16Type.get()),
+            res_type=ir.RankedTensorType.get([16, 32], ir.F32Type.get()),
+            lhs_map="affine_map<(d0, d1, d2) -> (d0, d2)>",
+            rhs_map="affine_map<(d0, d1, d2) -> (d1, d2)>",
+            res_map="affine_map<(d0, d1, d2) -> (d0, d1)>",
+            iterator_types='["parallel", "parallel", "reduction"]',
+        )
+    module = ir.Module.parse(transpose_b_str, context)
     parser = dispatch_parser.ContractionOpInterfaceParser()
     mmt_op = parser.get_contraction_operation(module)
     assert mmt_op is not None
     assert isinstance(mmt_op.opview, linalg.GenericOp)
+    shapes: common.ProblemSize = parser.get_shapes(transpose_b_str.splitlines())
+    assert shapes.matmul_size.B == 1
+    assert shapes.matmul_size.M == 16
+    assert shapes.matmul_size.N == 32
+    assert shapes.matmul_size.K == 64
+    assert shapes.lhs_type.shape == [16, 64]
+    assert isinstance(shapes.lhs_type.element_type, ir.F16Type)
+    assert shapes.rhs_type.shape == [32, 64]
+    assert isinstance(shapes.rhs_type.element_type, ir.F16Type)
+    assert shapes.res_type.shape == [16, 32]
+    assert isinstance(shapes.res_type.element_type, ir.F32Type)
+
+    with ir.Location.unknown():
+        bmm_transposed_inputs_str = CONTRACTION_TEMPLATE.format(
+            lhs_type=ir.RankedTensorType.get([5, 8, 128], ir.F16Type.get()),
+            rhs_type=ir.RankedTensorType.get([128, 40, 5], ir.F16Type.get()),
+            res_type=ir.RankedTensorType.get([5, 40, 8], ir.F32Type.get()),
+            lhs_map="affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>",
+            rhs_map="affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>",
+            res_map="affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>",
+            iterator_types='["parallel", "parallel", "parallel", "reduction"]',
+        )
+    module = ir.Module.parse(bmm_transposed_inputs_str, context)
+    mmt_op = parser.get_contraction_operation(module)
+    shapes: common.ProblemSize = parser.get_shapes(
+        bmm_transposed_inputs_str.splitlines()
+    )
+    assert shapes.matmul_size.B == 5
+    assert shapes.matmul_size.M == 8
+    assert shapes.matmul_size.N == 40
+    assert shapes.matmul_size.K == 128


 def test_get_conv_operation(tuner_ctx: common.TunerContext) -> None:
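
For reference, the batched, transposed-inputs expectations in the new test above can be checked by hand. A small standalone sketch (plain Python with the values copied from the test, not part of the commit itself):

# lhs map (d0, d1, d3), rhs map (d3, d2, d0), res map (d0, d2, d1),
# with d0=batch, d1=m, d2=n, d3=k.
lhs_dims, lhs_shape = [0, 1, 3], [5, 8, 128]    # tensor<5x8x128xf16>
rhs_dims, rhs_shape = [3, 2, 0], [128, 40, 5]   # tensor<128x40x5xf16>
batch_dim, m_dim, n_dim, k_dim = 0, 1, 2, 3

B = lhs_shape[lhs_dims.index(batch_dim)]  # 5
M = lhs_shape[lhs_dims.index(m_dim)]      # 8
N = rhs_shape[rhs_dims.index(n_dim)]      # 40
K = lhs_shape[lhs_dims.index(k_dim)]      # 128
assert (B, M, N, K) == (5, 8, 40, 128)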

tuner/tuner/op_matchers.py

Lines changed: 6 additions & 0 deletions
@@ -125,6 +125,9 @@ class ContractionOpInterfaceMatcher(GenericOpMatcher):
     def __init__(self):
         super().__init__()
         self.contraction_dimensions: Optional[ContractionDimensions] = None
+        self.lhs_dims: Optional[list[int]] = None
+        self.rhs_dims: Optional[list[int]] = None
+        self.res_dims: Optional[list[int]] = None

     def match_operands(self, operands: ir.OpOperandList) -> bool:
         if len(operands) != 3:
@@ -169,4 +172,7 @@ def match_indexing_maps(self, maps: list[ir.AffineMap]) -> bool:
             n=n_dims,
             k=k_dims,
         )
+        self.lhs_dims = lhs_dims
+        self.rhs_dims = rhs_dims
+        self.res_dims = res_dims
         return True
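
The cached lhs_dims/rhs_dims/res_dims are the iteration-space dimensions each operand's indexing map refers to, in operand order. As a standalone sketch (plain Python, not the matcher's actual code), the standard rule such lists encode for classifying contraction dimensions is:

def classify(lhs_dims, rhs_dims, res_dims):
    # batch: appears in lhs, rhs, and result; m: lhs and result only;
    # n: rhs and result only; k: lhs and rhs only.
    lhs, rhs, res = set(lhs_dims), set(rhs_dims), set(res_dims)
    return {
        "batch": sorted(lhs & rhs & res),
        "m": sorted((lhs & res) - rhs),
        "n": sorted((rhs & res) - lhs),
        "k": sorted((lhs & rhs) - res),
    }

print(classify([0, 1, 3], [3, 2, 0], [0, 2, 1]))
# {'batch': [0], 'm': [1], 'n': [2], 'k': [3]}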
