@@ -41,38 +41,82 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
41
41
)
42
42
43
43
44
# Parameterized MLIR module used by the contraction-parser tests below.
# Filled in with str.format(): the tensor types, affine maps, and iterator
# types vary per test case.  Literal MLIR braces are escaped as '{{'/'}}',
# so '{{root_op}}' renders as the literal '{root_op}' unit-attribute
# dictionary on the linalg.generic op after formatting.
CONTRACTION_TEMPLATE = r"""
builtin.module{{
func.func @test(%arg0: {lhs_type}, %arg1: {rhs_type}) -> {res_type} {{
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : {res_type}
%1 = linalg.fill ins(%cst : f32) outs(%0 : {res_type}) -> {res_type}
%2 = linalg.generic {{
indexing_maps = [
{lhs_map},
{rhs_map},
{res_map}],
iterator_types = {iterator_types}}}
{{root_op}}
ins(%arg0, %arg1 : {lhs_type}, {rhs_type})
outs(%1 : {res_type}) {{
^bb0(%in: f16, %in_0: f16, %out: f32):
%3 = arith.extf %in : f16 to f32
%4 = arith.extf %in_0 : f16 to f32
%5 = arith.mulf %3, %4 : f32
%6 = arith.addf %out, %5 : f32
linalg.yield %6 : f32
}} -> {res_type}
return %2 : {res_type}
}}
}}"""
44
71
def test_get_contraction_operation (tuner_ctx : common .TunerContext ) -> None :
45
72
context = tuner_ctx .mlir_ctx
46
- module_str = """
47
- builtin.module{
48
- func.func @test(%arg0: tensor<4x4xf16>, %arg1: tensor<4x4xf16>) -> tensor<4x4xf32> {
49
- %cst = arith.constant 0.000000e+00 : f32
50
- %0 = tensor.empty() : tensor<4x4xf32>
51
- %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x4xf32>) -> tensor<4x4xf32>
52
- %2 = linalg.generic {
53
- indexing_maps = [
54
- affine_map<(d0, d1, d2) -> (d0, d2)>,
55
- affine_map<(d0, d1, d2) -> (d1, d2)>,
56
- affine_map<(d0, d1, d2) -> (d0, d1)>],
57
- iterator_types = ["parallel", "parallel", "reduction"]}
58
- {root_op}
59
- ins(%arg0, %arg1 : tensor<4x4xf16>, tensor<4x4xf16>)
60
- outs(%1 : tensor<4x4xf32>) {
61
- ^bb0(%in: f16, %in_0: f16, %out: f32):
62
- %3 = arith.extf %in : f16 to f32
63
- %4 = arith.extf %in_0 : f16 to f32
64
- %5 = arith.mulf %3, %4 : f32
65
- %6 = arith.addf %out, %5 : f32
66
- linalg.yield %6 : f32
67
- } -> tensor<4x4xf32>
68
- return %2 : tensor<4x4xf32>
69
- }
70
- }"""
71
- module = ir .Module .parse (module_str , context )
73
+
74
+ with ir .Location .unknown ():
75
+ transpose_b_str = CONTRACTION_TEMPLATE .format (
76
+ lhs_type = ir .RankedTensorType .get ([16 , 64 ], ir .F16Type .get ()),
77
+ rhs_type = ir .RankedTensorType .get ([32 , 64 ], ir .F16Type .get ()),
78
+ res_type = ir .RankedTensorType .get ([16 , 32 ], ir .F32Type .get ()),
79
+ lhs_map = "affine_map<(d0, d1, d2) -> (d0, d2)>" ,
80
+ rhs_map = "affine_map<(d0, d1, d2) -> (d1, d2)>" ,
81
+ res_map = "affine_map<(d0, d1, d2) -> (d0, d1)>" ,
82
+ iterator_types = '["parallel", "parallel", "reduction"]' ,
83
+ )
84
+ module = ir .Module .parse (transpose_b_str , context )
72
85
parser = dispatch_parser .ContractionOpInterfaceParser ()
73
86
mmt_op = parser .get_contraction_operation (module )
74
87
assert mmt_op is not None
75
88
assert isinstance (mmt_op .opview , linalg .GenericOp )
89
+ shapes : common .ProblemSize = parser .get_shapes (transpose_b_str .splitlines ())
90
+ assert shapes .matmul_size .B == 1
91
+ assert shapes .matmul_size .M == 16
92
+ assert shapes .matmul_size .N == 32
93
+ assert shapes .matmul_size .K == 64
94
+ assert shapes .lhs_type .shape == [16 , 64 ]
95
+ assert isinstance (shapes .lhs_type .element_type , ir .F16Type )
96
+ assert shapes .rhs_type .shape == [32 , 64 ]
97
+ assert isinstance (shapes .rhs_type .element_type , ir .F16Type )
98
+ assert shapes .res_type .shape == [16 , 32 ]
99
+ assert isinstance (shapes .res_type .element_type , ir .F32Type )
100
+
101
+ with ir .Location .unknown ():
102
+ bmm_transposed_inputs_str = CONTRACTION_TEMPLATE .format (
103
+ lhs_type = ir .RankedTensorType .get ([5 , 8 , 128 ], ir .F16Type .get ()),
104
+ rhs_type = ir .RankedTensorType .get ([128 , 40 , 5 ], ir .F16Type .get ()),
105
+ res_type = ir .RankedTensorType .get ([5 , 40 , 8 ], ir .F32Type .get ()),
106
+ lhs_map = "affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>" ,
107
+ rhs_map = "affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>" ,
108
+ res_map = "affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>" ,
109
+ iterator_types = '["parallel", "parallel", "parallel", "reduction"]' ,
110
+ )
111
+ module = ir .Module .parse (bmm_transposed_inputs_str , context )
112
+ mmt_op = parser .get_contraction_operation (module )
113
+ shapes : common .ProblemSize = parser .get_shapes (
114
+ bmm_transposed_inputs_str .splitlines ()
115
+ )
116
+ assert shapes .matmul_size .B == 5
117
+ assert shapes .matmul_size .M == 8
118
+ assert shapes .matmul_size .N == 40
119
+ assert shapes .matmul_size .K == 128
76
120
77
121
78
122
def test_get_conv_operation (tuner_ctx : common .TunerContext ) -> None :
0 commit comments