Skip to content

Commit 7d055af

Browse files
caojoshuajoker-eph
andauthored
[mlir][Symbol] Add verification that symbol's parent is a SymbolTable (#80590)
Following the discussion in https://discourse.llvm.org/t/symboltable-and-symbol-parent-child-relationship/75446, we should enforce that a symbol's immediate parent is a symbol table. I changed some tests to pass the verification. In most cases, we can wrap the func with a module, change the func to another op with regions i.e. scf.if, or change the expected error message. --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent d53043f commit 7d055af

13 files changed

+68
-85
lines changed

mlir/include/mlir/IR/SymbolInterfaces.td

+5
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
171171
if (concreteOp.isDeclaration() && concreteOp.isPublic())
172172
return concreteOp.emitOpError("symbol declaration cannot have public "
173173
"visibility");
174+
auto parent = $_op->getParentOp();
175+
if (parent && !parent->hasTrait<OpTrait::SymbolTable>() && parent->isRegistered()) {
176+
return concreteOp.emitOpError("symbol's parent must have the SymbolTable "
177+
"trait");
178+
}
174179
return success();
175180
}];
176181

mlir/test/Dialect/LLVMIR/global.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label
132132
// -----
133133

134134
func.func @foo() {
135-
// expected-error @+1 {{must appear at the module level}}
135+
// expected-error @+1 {{op symbol's parent must have the SymbolTable trait}}
136136
llvm.mlir.global internal @bar(42) : i32
137137

138138
return

mlir/test/Dialect/Linalg/transform-op-replace.mlir

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ module attributes {transform.with_named_sequence} {
1212
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
1313
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
1414
transform.structured.replace %0 {
15-
func.func @foo() {
16-
"dummy_op"() : () -> ()
15+
builtin.module {
16+
func.func @foo() {
17+
"dummy_op"() : () -> ()
18+
}
1719
}
1820
} : (!transform.any_op) -> !transform.any_op
1921
transform.yield

mlir/test/Dialect/Transform/ops-invalid.mlir

+1-2
Original file line numberDiff line numberDiff line change
@@ -433,10 +433,9 @@ module {
433433
// -----
434434

435435
module attributes { transform.with_named_sequence} {
436-
// expected-note @below {{ancestor transform op}}
437436
transform.sequence failures(suppress) {
438437
^bb0(%arg0: !transform.any_op):
439-
// expected-error @below {{cannot be defined inside another transform op}}
438+
// expected-error @below {{op symbol's parent must have the SymbolTable trai}}
440439
transform.named_sequence @nested() {
441440
transform.yield
442441
}

mlir/test/IR/invalid-func-op.mlir

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func.func @func_op() {
3131
// -----
3232

3333
func.func @func_op() {
34-
// expected-error@+1 {{entry block must have 1 arguments to match function signature}}
34+
// expected-error@+1 {{op symbol's parent must have the SymbolTable trait}}
3535
func.func @mixed_named_arguments(f32) {
3636
^entry:
3737
return
@@ -42,7 +42,7 @@ func.func @func_op() {
4242
// -----
4343

4444
func.func @func_op() {
45-
// expected-error@+1 {{type of entry block argument #0('i32') must match the type of the corresponding argument in function signature('f32')}}
45+
// expected-error@+1 {{op symbol's parent must have the SymbolTable trait}}
4646
func.func @mixed_named_arguments(f32) {
4747
^entry(%arg : i32):
4848
return

mlir/test/IR/region.mlir

+3-4
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,17 @@ func.func @named_region_has_wrong_number_of_blocks() {
8787
// CHECK: test.single_no_terminator_op
8888
"test.single_no_terminator_op"() (
8989
{
90-
func.func @foo1() { return }
91-
func.func @foo2() { return }
90+
%foo = arith.constant 1 : i32
9291
}
9392
) : () -> ()
9493

9594
// CHECK: test.variadic_no_terminator_op
9695
"test.variadic_no_terminator_op"() (
9796
{
98-
func.func @foo1() { return }
97+
%foo = arith.constant 1 : i32
9998
},
10099
{
101-
func.func @foo2() { return }
100+
%bar = arith.constant 1 : i32
102101
}
103102
) : () -> ()
104103

mlir/test/IR/traits.mlir

+16-17
Original file line numberDiff line numberDiff line change
@@ -572,15 +572,13 @@ func.func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () {
572572

573573
// Ensure that SSACFG regions of operations in GRAPH regions are
574574
// checked for dominance
575-
func.func @illegalInsideDominanceFreeScope() -> () {
575+
func.func @illegalInsideDominanceFreeScope(%cond: i1) -> () {
576576
test.graph_region {
577-
func.func @test() -> i1 {
578-
^bb1:
577+
scf.if %cond {
579578
// expected-error @+1 {{operand #0 does not dominate this use}}
580579
%2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
581580
// expected-note @+1 {{operand defined here}}
582-
%1 = "baz"(%2#0) : (i1) -> (i64)
583-
return %2#1 : i1
581+
%1 = "baz"(%2#0) : (i1) -> (i64)
584582
}
585583
"terminator"() : () -> ()
586584
}
@@ -591,20 +589,21 @@ func.func @illegalInsideDominanceFreeScope() -> () {
591589

592590
// Ensure that SSACFG regions of operations in GRAPH regions are
593591
// checked for dominance
594-
func.func @illegalCDFGInsideDominanceFreeScope() -> () {
592+
func.func @illegalCFGInsideDominanceFreeScope(%cond: i1) -> () {
595593
test.graph_region {
596-
func.func @test() -> i1 {
597-
^bb1:
598-
// expected-error @+1 {{operand #0 does not dominate this use}}
599-
%2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
600-
cf.br ^bb4
601-
^bb2:
602-
cf.br ^bb2
603-
^bb4:
604-
%1 = "foo"() : ()->i64 // expected-note {{operand defined here}}
605-
return %2#1 : i1
594+
scf.if %cond {
595+
"test.ssacfg_region"() ({
596+
^bb1:
597+
// expected-error @+1 {{operand #0 does not dominate this use}}
598+
%2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
599+
cf.br ^bb4
600+
^bb2:
601+
cf.br ^bb2
602+
^bb4:
603+
%1 = "foo"() : ()->i64 // expected-note {{operand defined here}}
604+
}) : () -> ()
606605
}
607-
"terminator"() : () -> ()
606+
"terminator"() : () -> ()
608607
}
609608
return
610609
}

mlir/test/Transforms/canonicalize-dce.mlir

+7-7
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,15 @@ func.func @f(%arg0: f32, %pred: i1) {
7777

7878
// Test case: Recursively DCE into enclosed regions.
7979

80-
// CHECK: func @f(%arg0: f32)
81-
// CHECK-NEXT: func @g(%arg1: f32)
82-
// CHECK-NEXT: return
80+
// CHECK: func.func @f(%arg0: f32)
81+
// CHECK-NOT: arith.addf
8382

8483
func.func @f(%arg0: f32) {
85-
func.func @g(%arg1: f32) {
86-
%0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32
87-
return
88-
}
84+
"test.region"() (
85+
{
86+
%0 = "arith.addf"(%arg0, %arg0) : (f32, f32) -> f32
87+
}
88+
) : () -> ()
8989
return
9090
}
9191

mlir/test/Transforms/canonicalize.mlir

+6-7
Original file line numberDiff line numberDiff line change
@@ -424,16 +424,15 @@ func.func @write_only_alloca_fold(%v: f32) {
424424
// CHECK-LABEL: func @dead_block_elim
425425
func.func @dead_block_elim() {
426426
// CHECK-NOT: ^bb
427-
func.func @nested() {
428-
return
427+
builtin.module {
428+
func.func @nested() {
429+
return
429430

430-
^bb1:
431-
return
431+
^bb1:
432+
return
433+
}
432434
}
433435
return
434-
435-
^bb1:
436-
return
437436
}
438437

439438
// CHECK-LABEL: func @dyn_shape_fold(%arg0: index, %arg1: index)

mlir/test/Transforms/constant-fold.mlir

+7-4
Original file line numberDiff line numberDiff line change
@@ -756,12 +756,15 @@ func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
756756

757757
// CHECK-LABEL: func @nested_isolated_region
758758
func.func @nested_isolated_region() {
759+
// CHECK-NEXT: builtin.module {
759760
// CHECK-NEXT: func @isolated_op
760761
// CHECK-NEXT: arith.constant 2
761-
func.func @isolated_op() {
762-
%0 = arith.constant 1 : i32
763-
%2 = arith.addi %0, %0 : i32
764-
"foo.yield"(%2) : (i32) -> ()
762+
builtin.module {
763+
func.func @isolated_op() {
764+
%0 = arith.constant 1 : i32
765+
%2 = arith.addi %0, %0 : i32
766+
"foo.yield"(%2) : (i32) -> ()
767+
}
765768
}
766769

767770
// CHECK: "foo.unknown_region"

mlir/test/Transforms/cse.mlir

+7-4
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,14 @@ func.func @nested_isolated() -> i32 {
228228
// CHECK-NEXT: arith.constant 1
229229
%0 = arith.constant 1 : i32
230230

231+
// CHECK-NEXT: builtin.module
231232
// CHECK-NEXT: @nested_func
232-
func.func @nested_func() {
233-
// CHECK-NEXT: arith.constant 1
234-
%foo = arith.constant 1 : i32
235-
"foo.yield"(%foo) : (i32) -> ()
233+
builtin.module {
234+
func.func @nested_func() {
235+
// CHECK-NEXT: arith.constant 1
236+
%foo = arith.constant 1 : i32
237+
"foo.yield"(%foo) : (i32) -> ()
238+
}
236239
}
237240

238241
// CHECK: "foo.region"

mlir/test/Transforms/test-legalizer-full.mlir

+5-3
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ func.func @recursively_legal_invalid_op() {
3737
}
3838
/// Operation that is dynamically legal, i.e. the function has a pattern
3939
/// applied to legalize the argument type before it becomes recursively legal.
40-
func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
41-
%ignored = "test.illegal_op_f"() : () -> (i32)
42-
"test.return"() : () -> ()
40+
builtin.module {
41+
func.func @dynamic_func(%arg: i64) attributes {test.recursively_legal} {
42+
%ignored = "test.illegal_op_f"() : () -> (i32)
43+
"test.return"() : () -> ()
44+
}
4345
}
4446

4547
"test.return"() : () -> ()

mlir/test/python/ir/value.py

+4-32
Original file line numberDiff line numberDiff line change
@@ -167,28 +167,15 @@ def testValuePrintAsOperand():
167167
print(value2)
168168

169169
topFn = func.FuncOp("test", ([i32, i32], []))
170-
entry_block1 = Block.create_at_start(topFn.operation.regions[0], [i32, i32])
170+
entry_block = Block.create_at_start(topFn.operation.regions[0], [i32, i32])
171171

172-
with InsertionPoint(entry_block1):
172+
with InsertionPoint(entry_block):
173173
value3 = Operation.create("custom.op3", results=[i32]).results[0]
174174
# CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
175175
print(value3)
176176
value4 = Operation.create("custom.op4", results=[i32]).results[0]
177177
# CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
178178
print(value4)
179-
180-
f = func.FuncOp("test", ([i32, i32], []))
181-
entry_block2 = Block.create_at_start(f.operation.regions[0], [i32, i32])
182-
with InsertionPoint(entry_block2):
183-
value5 = Operation.create("custom.op5", results=[i32]).results[0]
184-
# CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
185-
print(value5)
186-
value6 = Operation.create("custom.op6", results=[i32]).results[0]
187-
# CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
188-
print(value6)
189-
190-
func.ReturnOp([])
191-
192179
func.ReturnOp([])
193180

194181
# CHECK: %[[VAL1]]
@@ -215,32 +202,17 @@ def testValuePrintAsOperand():
215202
# CHECK: %1
216203
print(value4.get_name(use_local_scope=True))
217204

218-
# CHECK: %[[VAL5]]
219-
print(value5.get_name())
220-
# CHECK: %[[VAL6]]
221-
print(value6.get_name())
222-
223205
# CHECK: %[[ARG0:.*]]
224-
print(entry_block1.arguments[0].get_name())
206+
print(entry_block.arguments[0].get_name())
225207
# CHECK: %[[ARG1:.*]]
226-
print(entry_block1.arguments[1].get_name())
227-
228-
# CHECK: %[[ARG2:.*]]
229-
print(entry_block2.arguments[0].get_name())
230-
# CHECK: %[[ARG3:.*]]
231-
print(entry_block2.arguments[1].get_name())
208+
print(entry_block.arguments[1].get_name())
232209

233210
# CHECK: module {
234211
# CHECK: %[[VAL1]] = "custom.op1"() : () -> i32
235212
# CHECK: %[[VAL2]] = "custom.op2"() : () -> i32
236213
# CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
237214
# CHECK: %[[VAL3]] = "custom.op3"() : () -> i32
238215
# CHECK: %[[VAL4]] = "custom.op4"() : () -> i32
239-
# CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
240-
# CHECK: %[[VAL5]] = "custom.op5"() : () -> i32
241-
# CHECK: %[[VAL6]] = "custom.op6"() : () -> i32
242-
# CHECK: return
243-
# CHECK: }
244216
# CHECK: return
245217
# CHECK: }
246218
# CHECK: }

0 commit comments

Comments
 (0)