Skip to content

Commit 03881dc

Browse files
authored
[mlir][emitc] Add a declare_func operation (#80297)
This adds the `emitc.declare_func` operation that allows to emit the declaration of an `emitc.func` at a specific location.
1 parent 47a12cc commit 03881dc

File tree

6 files changed

+127
-6
lines changed

6 files changed

+127
-6
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

+42
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,48 @@ def EmitC_CallOp : EmitC_Op<"call",
460460
}];
461461
}
462462

463+
def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [
464+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
465+
]> {
466+
let summary = "An operation to declare a function";
467+
let description = [{
468+
The `declare_func` operation allows to insert a function declaration for an
469+
`emitc.func` at a specific position. The operation only requires the `callee`
470+
of the `emitc.func` to be specified as an attribute.
471+
472+
Example:
473+
474+
```mlir
475+
emitc.declare_func @bar
476+
emitc.func @foo(%arg0: i32) -> i32 {
477+
%0 = emitc.call @bar(%arg0) : (i32) -> (i32)
478+
emitc.return %0 : i32
479+
}
480+
481+
emitc.func @bar(%arg0: i32) -> i32 {
482+
emitc.return %arg0 : i32
483+
}
484+
```
485+
486+
```c++
487+
// Code emitted for the operations above.
488+
int32_t bar(int32_t v1);
489+
int32_t foo(int32_t v1) {
490+
int32_t v2 = bar(v1);
491+
return v2;
492+
}
493+
494+
int32_t bar(int32_t v1) {
495+
return v1;
496+
}
497+
```
498+
}];
499+
let arguments = (ins FlatSymbolRefAttr:$sym_name);
500+
let assemblyFormat = [{
501+
$sym_name attr-dict
502+
}];
503+
}
504+
463505
def EmitC_FuncOp : EmitC_Op<"func", [
464506
AutomaticAllocationScope,
465507
FunctionOpInterface, IsolatedFromAbove

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,24 @@ FunctionType CallOp::getCalleeType() {
393393
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
394394
}
395395

396+
//===----------------------------------------------------------------------===//
397+
// DeclareFuncOp
398+
//===----------------------------------------------------------------------===//
399+
400+
LogicalResult
401+
DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
402+
// Check that the sym_name attribute was specified.
403+
auto fnAttr = getSymNameAttr();
404+
if (!fnAttr)
405+
return emitOpError("requires a 'sym_name' symbol reference attribute");
406+
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
407+
if (!fn)
408+
return emitOpError() << "'" << fnAttr.getValue()
409+
<< "' does not reference a valid function";
410+
411+
return success();
412+
}
413+
396414
//===----------------------------------------------------------------------===//
397415
// FuncOp
398416
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

+39-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/BuiltinTypes.h"
1515
#include "mlir/IR/Dialect.h"
1616
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/SymbolTable.h"
1718
#include "mlir/Support/IndentedOstream.h"
1819
#include "mlir/Support/LLVM.h"
1920
#include "mlir/Target/Cpp/CppEmitter.h"
@@ -855,8 +856,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
855856
// needs to be printed after the closing brace.
856857
// When generating code for an emitc.for and emitc.verbatim op, printing a
857858
// trailing semicolon is handled within the printOperation function.
858-
bool trailingSemicolon = !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp,
859-
emitc::LiteralOp, emitc::VerbatimOp>(op);
859+
bool trailingSemicolon =
860+
!isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
861+
emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
860862

861863
if (failed(emitter.emitOperation(
862864
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -938,6 +940,37 @@ static LogicalResult printOperation(CppEmitter &emitter,
938940
return success();
939941
}
940942

943+
static LogicalResult printOperation(CppEmitter &emitter,
944+
DeclareFuncOp declareFuncOp) {
945+
CppEmitter::Scope scope(emitter);
946+
raw_indented_ostream &os = emitter.ostream();
947+
948+
auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
949+
declareFuncOp, declareFuncOp.getSymNameAttr());
950+
951+
if (!functionOp)
952+
return failure();
953+
954+
if (functionOp.getSpecifiers()) {
955+
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
956+
os << cast<StringAttr>(specifier).str() << " ";
957+
}
958+
}
959+
960+
if (failed(emitter.emitTypes(functionOp.getLoc(),
961+
functionOp.getFunctionType().getResults())))
962+
return failure();
963+
os << " " << functionOp.getName();
964+
965+
os << "(";
966+
Operation *operation = functionOp.getOperation();
967+
if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
968+
return failure();
969+
os << ");";
970+
971+
return success();
972+
}
973+
941974
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
942975
: os(os), declareVariablesAtTop(declareVariablesAtTop) {
943976
valueInScopeCount.push(0);
@@ -1251,10 +1284,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
12511284
// EmitC ops.
12521285
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
12531286
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1254-
emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp,
1255-
emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp,
1256-
emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1257-
emitc::VariableOp, emitc::VerbatimOp>(
1287+
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
1288+
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
1289+
emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1290+
emitc::SubOp, emitc::VariableOp, emitc::VerbatimOp>(
12581291
[&](auto op) { return printOperation(*this, op); })
12591292
// Func ops.
12601293
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(

mlir/test/Dialect/EmitC/invalid_ops.mlir

+10
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,13 @@ func.func @return_inside_func.func(%0: i32) -> (i32) {
321321

322322
// expected-error@+1 {{expected non-function type}}
323323
emitc.func @func_variadic(...)
324+
325+
// -----
326+
327+
// expected-error@+1 {{'emitc.declare_func' op 'bar' does not reference a valid function}}
328+
emitc.declare_func @bar
329+
330+
// -----
331+
332+
// expected-error@+1 {{'emitc.declare_func' op requires attribute 'sym_name'}}
333+
"emitc.declare_func"() : () -> ()

mlir/test/Dialect/EmitC/ops.mlir

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) {
1515
return
1616
}
1717

18+
emitc.declare_func @func
19+
1820
emitc.func @func(%arg0 : i32) {
1921
emitc.call_opaque "foo"(%arg0) : (i32) -> ()
2022
emitc.return
+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
3+
// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]);
4+
emitc.declare_func @bar
5+
// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]) {
6+
emitc.func @bar(%arg0: i32) -> i32 {
7+
emitc.return %arg0 : i32
8+
}
9+
10+
11+
// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]);
12+
emitc.declare_func @foo
13+
// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]) {
14+
emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} {
15+
emitc.return %arg0 : i32
16+
}

0 commit comments

Comments
 (0)