From 041ed80a9d589c66fe1682d6c3bebe8209603698 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Fri, 31 Jan 2025 23:30:31 -0800 Subject: [PATCH 1/2] Translate action calls Signed-off-by: Anton Korobeynikov --- include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td | 90 +++++++++++++++- lib/Dialect/P4HIR/P4HIR_Ops.cpp | 36 +++++++ test/Dialect/P4HIR/call.mlir | 27 +++++ test/Translate/Ops/calls.p4 | 53 +++++++++ tools/p4mlir-translate/main.cpp | 2 + tools/p4mlir-translate/translate.cpp | 126 ++++++++++++++++++++-- tools/p4mlir-translate/translate.h | 3 +- 7 files changed, 319 insertions(+), 18 deletions(-) create mode 100644 test/Dialect/P4HIR/call.mlir create mode 100644 test/Translate/Ops/calls.p4 diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td index 39a0667..e5e61c3 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td @@ -175,6 +175,7 @@ def StoreOp : P4HIR_Op<"store", [ // TODO: Decide if we'd want to be more precise and split cast into // bitcast, trunc and extensions +// TODO: Add CastOpInterface def CastOp : P4HIR_Op<"cast", [Pure/*, DeclareOpInterfaceMethods*/]> { @@ -385,7 +386,6 @@ def YieldOp : P4HIR_Op<"yield", [ReturnLike, Terminator, ParentOneOf<["ScopeOp", "TernaryOp", "IfOp", // "SwitchOp", "CaseOp", // "ForInOp", "ForOp", - // "CallOp" ]>]> { let summary = "Represents the default branching behaviour of a region"; let description = [{ @@ -538,7 +538,7 @@ def ReturnOp : P4HIR_Op<"return", [ParentOneOf<["ScopeOp", "IfOp", // "SwitchOp", "CaseOp", // "ForInOp", "ForOp", ]>, - Terminator]> { + Terminator, ReturnLike]> { let summary = "Return from function or action"; let description = [{ The "return" operation represents a return operation within a function or action. @@ -630,9 +630,11 @@ def ActionOp : P4HIR_Op<"action", [ let regions = (region AnyRegion:$body); let skipDefaultBuilders = 1; - let builders = [OpBuilder<(ins "llvm::StringRef":$name, "ActionType":$type, - CArg<"llvm::ArrayRef", "{}">:$attrs, - CArg<"llvm::ArrayRef", "{}">:$argAttrs)>]; + let builders = [ + OpBuilder<(ins "llvm::StringRef":$name, "ActionType":$type, + CArg<"llvm::ArrayRef", "{}">:$attrs, + CArg<"llvm::ArrayRef", "{}">:$argAttrs)> + ]; let extraClassDeclaration = [{ /// Returns the region on the current operation that is callable. Always @@ -682,4 +684,82 @@ def ActionOp : P4HIR_Op<"action", [ let hasVerifier = 1; } +def CallOp : P4HIR_Op<"call", + [NoRegionArguments, CallOpInterface, + DeclareOpInterfaceMethods]> { + let summary = "call operation"; + let description = [{ + The `call` operation represents a direct call to a method (action, function, etc.) that + is within the same symbol scope as the call. The operands and result types of the call must + match the specified function type. The callee is encoded as a symbol reference attribute + named "callee". + + Example: + + ```mlir + // Direct call of function + %2 = p4hir.call @my_add(%0, %1) : (!p4hir.bit<8>, !p4hir.bit<8>) -> !p4hir.bit<8> + // Direct call of action + %4 = p4hir.call @my_add(%0, %1) : (!p4hir.bit<8>, !p4hir.bit<8>) -> () + ... + ``` + }]; + + // TODO: Refine result types, refine parameter type + let results = (outs Optional:$result); + let arguments = (ins OptionalAttr:$callee, Variadic:$operands); + + let skipDefaultBuilders = 1; + let hasVerifier = 0; + + let builders = [ + // Functions + OpBuilder<(ins "mlir::SymbolRefAttr":$callee, "mlir::Type":$resType, + CArg<"mlir::ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + $_state.addTypes(resType); + }]>, + // Everything else that does not produce result + OpBuilder<(ins "mlir::SymbolRefAttr":$callee, + CArg<"mlir::ValueRange", "{}">:$operands), [{ + $_state.addOperands(operands); + $_state.addAttribute("callee", callee); + }]> + + ]; + + let extraClassDeclaration = [{ + /// Get the argument operands to the called function. + mlir::OperandRange getArgOperands() { + return {arg_operand_begin(), arg_operand_end()}; + } + + mlir::MutableOperandRange getArgOperandsMutable() { + return getOperandsMutable(); + } + + operand_iterator arg_operand_begin() { return operand_begin(); } + operand_iterator arg_operand_end() { return operand_end(); } + + /// Return the callee of this operation. + ::mlir::CallInterfaceCallable getCallableForCallee() { + return (*this)->getAttrOfType("callee"); + } + + /// Set the callee for this operation. + void setCalleeFromCallable(mlir::CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } + + void setArg(unsigned index, mlir::Value value) { + setOperand(index, value); + } + }]; + + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + #endif // P4MLIR_DIALECT_P4HIR_P4HIR_OPS_TD diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp index 8afc496..c00f1df 100644 --- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp @@ -430,6 +430,42 @@ ParseResult P4HIR::ActionOp::parse(OpAsmParser &parser, OperationState &state) { return success(); } +LogicalResult P4HIR::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this)->getAttrOfType("callee"); + if (!fnAttr) return emitOpError("requires a 'callee' symbol reference attribute"); + ActionOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + if (!fn) + return emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; + + // Verify that the operand and result types match the callee. + auto fnType = fn.getFunctionType(); + if (fnType.getNumInputs() != getNumOperands()) + return emitOpError("incorrect number of operands for callee"); + + for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) + if (getOperand(i).getType() != fnType.getInput(i)) + return emitOpError("operand type mismatch: expected operand type ") + << fnType.getInput(i) << ", but provided " << getOperand(i).getType() + << " for operand number " << i; + + // FIXME: Generalize to functions + if (0 != getNumResults()) return emitOpError("incorrect number of results for callee"); + + /* + if (fnType.getNumResults( != getNumResults()) + return emitOpError("incorrect number of results for callee"); + + for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) + if (getResult(i).getType() != fnType.getResult(i)) { + auto diag = emitOpError("result type mismatch at index ") << i; + diag.attachNote() << " op result types: " << getResultTypes(); + diag.attachNote() << "function result types: " << fnType.getResults(); + return diag; + }*/ + + return success(); +} void P4HIR::P4HIRDialect::initialize() { registerTypes(); registerAttributes(); diff --git a/test/Dialect/P4HIR/call.mlir b/test/Dialect/P4HIR/call.mlir new file mode 100644 index 0000000..89659c2 --- /dev/null +++ b/test/Dialect/P4HIR/call.mlir @@ -0,0 +1,27 @@ +// RUN: p4mlir-opt %s | FileCheck %s + +!bit32 = !p4hir.bit<32> + +// CHECK: module +// CHECK-LABEL: p4hir.action @foo(%arg0: !p4hir.ref> {p4hir.dir = #p4hir}, %arg1: !p4hir.bit<32> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}, %arg3: !p4hir.int<42>) { +p4hir.action @foo(%arg0 : !p4hir.ref {p4hir.dir = #p4hir}, + %arg1 : !bit32 {p4hir.dir = #p4hir}, + %arg2 : !p4hir.ref {p4hir.dir = #p4hir}, + %arg3 : !p4hir.int<42>) { + %0 = p4hir.alloca !bit32 ["tmp"] : !p4hir.ref + %1 = p4hir.load %arg0 : !p4hir.ref, !bit32 + + p4hir.store %arg1, %0 : !bit32, !p4hir.ref + p4hir.store %1, %arg2 : !bit32, !p4hir.ref + + p4hir.return +} + +p4hir.action @bar() { + %0 = p4hir.alloca !bit32 ["tmp"] : !p4hir.ref + %1 = p4hir.load %0 : !p4hir.ref, !bit32 + %3 = p4hir.const #p4hir.int<7> : !p4hir.int<42> + p4hir.call @foo(%0, %1, %0, %3) : (!p4hir.ref, !bit32, !p4hir.ref, !p4hir.int<42>) -> () + + p4hir.return +} diff --git a/test/Translate/Ops/calls.p4 b/test/Translate/Ops/calls.p4 new file mode 100644 index 0000000..25ddcea --- /dev/null +++ b/test/Translate/Ops/calls.p4 @@ -0,0 +1,53 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +// CHECK-LABEL: foo +action foo(in int<16> arg2, bit<10> arg1) { + int<16> x1 = 3; + x1 = arg2; + bit<10> x2 = arg1; +} + +// CHECK-LABEL: bar +action bar() { + int<16> x1 = 2; + return; +} + +// CHECK-LABEL: baz +action baz(inout int<16> x) { + x = x + 1; + return; +} + +// CHECK-LABEL: quuz +action quuz(out int<16> a) { + a = 42; +} + +// CHECK-LABEL: bazz +action bazz(in int<16> arg1) { + // CHECK: p4hir.call @foo(%arg0, %{{.*}}) : (!p4hir.int<16>, !p4hir.bit<10>) -> () + foo(arg1, 7); + bit<10> x1 = 5; + // CHECK: p4hir.call @foo(%arg0, %{{.*}}) : (!p4hir.int<16>, !p4hir.bit<10>) -> () + foo(arg1, x1); + // CHECK: p4hir.call @foo(%{{.*}}, %{{.*}}) : (!p4hir.int<16>, !p4hir.bit<10>) -> () + foo(4, 2); + // CHECK: p4hir.call @bar() : () -> () + bar(); + // CHECK: p4hir.scope + // CHECK: %[[VAR_A:.*]] = p4hir.alloca !p4hir.int<16> ["a"] : !p4hir.ref> + // CHECK: p4hir.call @quuz(%[[VAR_A]]) : (!p4hir.ref>) -> () + // CHECK: p4hir.load %[[VAR_A]] : !p4hir.ref>, !p4hir.int<16> + int<16> val; + quuz(val); + // CHECK: p4hir.scope + // CHECK: %[[VAR_X:.*]] = p4hir.alloca !p4hir.int<16> ["x", init] : !p4hir.ref> + // CHECK: %[[VAL_X:.*]] = p4hir.load %[[VAL:.*]] : !p4hir.ref>, !p4hir.int<16> + // CHECK: p4hir.store %[[VAL_X]], %[[VAR_X]] : !p4hir.int<16>, !p4hir.ref> + // CHECK: p4hir.call @baz(%[[VAR_X]]) + // CHECK: %[[OUT_X:.*]] = p4hir.load %[[VAR_X]] : !p4hir.ref>, !p4hir.int<16> + // CHECK: p4hir.store %[[OUT_X]], %[[VAL]] : !p4hir.int<16>, !p4hir.ref> + baz(val); + return; +} diff --git a/tools/p4mlir-translate/main.cpp b/tools/p4mlir-translate/main.cpp index c07aa3b..d695259 100644 --- a/tools/p4mlir-translate/main.cpp +++ b/tools/p4mlir-translate/main.cpp @@ -129,6 +129,8 @@ int main(int argc, char *const argv[]) { } } + if (P4::errorCount() > 0) return EXIT_FAILURE; + BUG_CHECK(options.typeinferenceOnly, "TODO: fill TypeMap"); log_dump(program, "After frontend"); diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index 7bedb57..b3e10f3 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -1,19 +1,23 @@ #include "translate.h" +#include #include +#include "ir/ir-generated.h" +#include "p4mlir/third_party/llvm-project/mlir/include/mlir/IR/Value.h" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wcovered-switch-default" #include "frontends/common/resolveReferences/resolveReferences.h" +#include "frontends/p4/methodInstance.h" #include "frontends/p4/typeMap.h" -#include "ir/ir-generated.h" #include "ir/ir.h" #include "ir/visitor.h" #include "lib/big_int.h" #include "lib/indent.h" #include "lib/log.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/BuiltinAttributes.h" +#pragma GCC diagnostic pop + #include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.h" #include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h" #include "p4mlir/Dialect/P4HIR/P4HIR_Ops.h" @@ -23,9 +27,14 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfoVariant.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" @@ -136,7 +145,7 @@ class P4TypeConverter : public P4::Inspector { class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { mlir::OpBuilder &builder; - const P4::TypeMap *typeMap = nullptr; + P4::TypeMap *typeMap = nullptr; llvm::DenseMap p4Types; // TODO: Implement unified constant map // using CTVOrExpr = std::variant p4Constants; llvm::DenseMap p4Constants; llvm::DenseMap p4Values; + using P4Symbol = std::variant; + // TODO: Implement better scoped symbol table + llvm::DenseMap p4Symbols; mlir::TypedAttr resolveConstant(const P4::IR::CompileTimeValue *ctv); mlir::TypedAttr resolveConstantExpr(const P4::IR::Expression *expr); @@ -156,7 +168,7 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { } public: - P4HIRConverter(mlir::OpBuilder &builder, const P4::TypeMap *typeMap) + P4HIRConverter(mlir::OpBuilder &builder, P4::TypeMap *typeMap) : builder(builder), typeMap(typeMap) { CHECK_NULL(typeMap); } @@ -340,6 +352,12 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { bool preorder(const P4::IR::LOr *lor) override; bool preorder(const P4::IR::LAnd *land) override; bool preorder(const P4::IR::IfStatement *ifs) override; + bool preorder(const P4::IR::MethodCallStatement *) override { + // We handle MethodCallExpression instead + return true; + } + + bool preorder(const P4::IR::MethodCallExpression *mce) override; mlir::Value emitUnOp(const P4::IR::Operation_Unary *unop, P4HIR::UnaryOpKind kind); mlir::Value emitBinOp(const P4::IR::Operation_Binary *binop, P4HIR::BinOpKind kind); @@ -379,11 +397,13 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Unknown *type) { } bool P4TypeConverter::preorder(const P4::IR::Type_Name *name) { + if ((this->type = converter.findType(name))) return false; + ConversionTracer trace("TypeConverting ", name); const auto *type = converter.resolveType(name); CHECK_NULL(type); - visit(type); - return false; + mlir::Type mlirType = convert(type); + return setType(name, mlirType); } bool P4TypeConverter::preorder(const P4::IR::Type_Action *type) { @@ -674,6 +694,9 @@ bool P4HIRConverter::preorder(const P4::IR::IfStatement *ifs) { bool P4HIRConverter::preorder(const P4::IR::P4Action *act) { ConversionTracer trace("Converting ", act); + // TODO: Actions might reference some control locals, we need to make + // them visible somehow (e.g. via additional arguments) + // FIXME: Get rid of typeMap: ensure action knows its type auto actType = mlir::cast(getOrCreateType(typeMap->getType(act, true))); const auto ¶ms = act->getParameters()->parameters; @@ -731,10 +754,15 @@ bool P4HIRConverter::preorder(const P4::IR::P4Action *act) { } } + auto [it, inserted] = p4Symbols.try_emplace(act, mlir::SymbolRefAttr::get(action)); + BUG_CHECK(inserted, "duplicate translation of %1%", act); + return false; } void P4HIRConverter::postorder(const P4::IR::ReturnStatement *ret) { + ConversionTracer trace("Converting ", ret); + // TODO: ReturnOp is a terminator, so it cannot be in the middle of block; // ensure nothing is created afterwards if (ret->expression) { @@ -745,13 +773,89 @@ void P4HIRConverter::postorder(const P4::IR::ReturnStatement *ret) { } } +bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { + ConversionTracer trace("Converting ", mce); + const auto *instance = + P4::MethodInstance::resolve(mce, this, typeMap, false, getChildContext()); + const auto ¶ms = instance->originalMethodType->parameters->parameters; + + // TODO: Actions might have some parameters coming from control plane + + // Prepare call arguments. Note that this involves creating temporaries to + // model copy-in/out semantics. To limit the lifetime of those temporaries, do + // everything in the dedicated block scope. If there are no out parameters, + // then emit everything direct. + auto convertCall = [&](mlir::OpBuilder &b, mlir::Location loc) { + llvm::SmallVector operands; + for (auto [idx, arg] : llvm::enumerate(*mce->arguments)) { + ConversionTracer trace("Converting ", arg); + visit(arg->expression); + mlir::Value argVal; + switch (auto dir = params[idx]->direction) { + case P4::IR::Direction::None: + case P4::IR::Direction::In: + // Nothing to do special, just pass things direct + argVal = getValue(arg->expression); + break; + case P4::IR::Direction::Out: + case P4::IR::Direction::InOut: { + // Just create temporary to hold the output value, initialize in case of inout + auto ref = resolveReference(arg->expression); + auto type = mlir::cast(ref.getType()); + + auto copyIn = b.create(loc, type, type.getObjectType(), + params[idx]->name.string_view()); + + if (dir == P4::IR::Direction::InOut) { + copyIn.setInit(true); + b.create(loc, getValue(arg->expression), copyIn); + } + argVal = copyIn; + break; + } + } + operands.push_back(argVal); + } + + if (const auto *actCall = instance->to()) { + auto actSym = p4Symbols.lookup(actCall->action); + BUG_CHECK(actSym, "expected reference action to be converted: %1%", actCall->action); + + b.create(loc, actSym, operands); + } else { + BUG("unsupported call type: %1%", instance); + } + + for (auto [idx, arg] : llvm::enumerate(*mce->arguments)) { + // Determine the direction of the parameter + if (!params[idx]->hasOut()) continue; + + mlir::Value copyOut = operands[idx]; + mlir::Value dest = resolveReference(arg->expression); + b.create( + getEndLoc(builder, mce), + builder.create(getEndLoc(builder, mce), copyOut), dest); + } + }; + + if (std::any_of(params.begin(), params.end(), [](const auto *p) { return p->hasOut(); })) { + mlir::OpBuilder::InsertionGuard guard(builder); + auto scope = builder.create(getLoc(builder, mce), convertCall); + builder.setInsertionPointToEnd(&scope.getScopeRegion().back()); + builder.create(getEndLoc(builder, mce)); + } else { + convertCall(builder, getLoc(builder, mce)); + } + + return false; +} + } // namespace namespace P4::P4MLIR { mlir::OwningOpRef toMLIR(mlir::MLIRContext &context, - const P4::IR::P4Program *program, - const P4::TypeMap *typeMap) { + const P4::IR::P4Program *program, P4::TypeMap *typeMap) { mlir::OpBuilder builder(&context); auto moduleOp = mlir::ModuleOp::create(builder.getUnknownLoc()); diff --git a/tools/p4mlir-translate/translate.h b/tools/p4mlir-translate/translate.h index 5d1042c..9219908 100644 --- a/tools/p4mlir-translate/translate.h +++ b/tools/p4mlir-translate/translate.h @@ -13,6 +13,5 @@ class TypeMap; namespace P4::P4MLIR { mlir::OwningOpRef toMLIR(mlir::MLIRContext &context, - const P4::IR::P4Program *program, - const P4::TypeMap *typeMap); + const P4::IR::P4Program *program, P4::TypeMap *typeMap); } // namespace P4::P4MLIR From 1f7501c240034e7b9b258a844faa87e6a48c720e Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Tue, 4 Feb 2025 00:15:55 -0800 Subject: [PATCH 2/2] Generalize actions to functions. Implement function lowering Signed-off-by: Anton Korobeynikov --- include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td | 75 ++++++-- include/p4mlir/Dialect/P4HIR/P4HIR_Types.td | 61 +++++-- lib/Dialect/P4HIR/P4HIR_Ops.cpp | 158 +++++++++++----- lib/Dialect/P4HIR/P4HIR_Types.cpp | 71 ++++++-- test/Dialect/P4HIR/action.mlir | 10 +- test/Dialect/P4HIR/call.mlir | 12 +- test/Dialect/P4HIR/types.mlir | 4 +- test/Translate/Ops/action.p4 | 4 +- test/Translate/Ops/assign.p4 | 2 +- test/Translate/Ops/binop.p4 | 4 +- test/Translate/Ops/calls.p4 | 4 +- test/Translate/Ops/cmp.p4 | 2 +- test/Translate/Ops/function.p4 | 64 +++++++ test/Translate/Ops/scope.p4 | 2 +- test/Translate/Ops/unop.p4 | 2 +- test/Translate/Ops/variables.p4 | 2 +- tools/p4mlir-translate/translate.cpp | 189 ++++++++++++++++---- 17 files changed, 521 insertions(+), 145 deletions(-) create mode 100644 test/Translate/Ops/function.p4 diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td index e5e61c3..cf07029 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Ops.td @@ -96,7 +96,7 @@ def AllocaOp : P4HIR_Op<"alloca", [ let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "mlir::Type":$ref, "mlir::Type":$objectType, "llvm::StringRef":$name)>, + OpBuilder<(ins "mlir::Type":$ref, "mlir::Type":$objectType, "const llvm::Twine &":$name)>, ]; let assemblyFormat = [{ @@ -534,11 +534,16 @@ def IfOp : P4HIR_Op<"if", } def ReturnOp : P4HIR_Op<"return", [ParentOneOf<["ScopeOp", "IfOp", - "ActionOp", + "FuncOp", // "SwitchOp", "CaseOp", // "ForInOp", "ForOp", ]>, - Terminator, ReturnLike]> { + // Note that ReturnOp is not ReturnLike: currently there is no way to + // represent early exits in MLIR "properly" + // We might not be able to have it a Terminator at this level in order + // to represent dead code. We might lower it to proper terminator later (!) + // See https://discourse.llvm.org/t/rfc-region-based-control-flow-with-early-exits-in-mlir/76998 + Terminator]> { let summary = "Return from function or action"; let description = [{ The "return" operation represents a return operation within a function or action. @@ -579,13 +584,12 @@ def ReturnOp : P4HIR_Op<"return", [ParentOneOf<["ScopeOp", "IfOp", let hasVerifier = 1; } -def ActionOp : P4HIR_Op<"action", [ +def FuncOp : P4HIR_Op<"func", [ AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove ]> { - let summary = "Define an action"; + let summary = "Define a function-like object (action, function)"; let description = [{ - Similar to `mlir::FuncOp` built-in: > Operations within the function cannot implicitly capture values defined > outside of the function, i.e. Functions are `IsolatedFromAbove`. All @@ -600,10 +604,13 @@ def ActionOp : P4HIR_Op<"action", [ > Only dialect attribute names may be specified in the attribute dictionaries > for function arguments, results, or the function itself. - Action parameters might have direction that is specified via `p4hir.dir` + Parameters might have direction that is specified via `p4hir.dir` attribute. Out and inout parameters must have a reference type. All refence-typed parameters must have a direction and it should be `out` or `input`. + An action must be marked as `action`, should always have a body and cannot return + anything. + Example: ```mlir @@ -624,22 +631,55 @@ def ActionOp : P4HIR_Op<"action", [ }]; let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$function_type, + TypeAttrOf:$function_type, + UnitAttr:$action, OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$annotations); let regions = (region AnyRegion:$body); let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "llvm::StringRef":$name, "ActionType":$type, + OpBuilder<(ins "llvm::StringRef":$name, "FuncType":$type, CArg<"llvm::ArrayRef", "{}">:$attrs, CArg<"llvm::ArrayRef", "{}">:$argAttrs)> ]; let extraClassDeclaration = [{ + // TODO: move to custom builder + static FuncOp buildAction(mlir::OpBuilder &builder, + mlir::Location loc, + llvm::StringRef name, + P4HIR::FuncType type, + llvm::ArrayRef attrs = {}, + llvm::ArrayRef argAttrs = {}) { + auto op = builder.create(loc, name, type, attrs, argAttrs); + op.createEntryBlock(); + op.setAction(true); + return op; + } + /// Returns the region on the current operation that is callable. Always /// non-null for actions. - mlir::Region *getCallableRegion(); + mlir::Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } + + /// Returns the results types that the callable region produces when + /// executed. + llvm::ArrayRef getCallableResults() { + return getFunctionType().getReturnTypes(); + } + + /// Returns the argument attributes for all callable region arguments or + /// null if there are none. + ::mlir::ArrayAttr getCallableArgAttrs() { + return getArgAttrs().value_or(nullptr); + } + + /// Returns the result attributes for all callable region results or null if + /// there are none. + ::mlir::ArrayAttr getCallableResAttrs() { + return getResAttrs().value_or(nullptr); + } /// Returns the argument types of this action. llvm::ArrayRef getArgumentTypes() { @@ -667,10 +707,10 @@ def ActionOp : P4HIR_Op<"action", [ return ParamDirection::None; } - /// Returns the result types of this action. Required for FunctionOp - /// interface, so return empty. + /// Returns 0 or 1 result type of this function (0 in the case of a function + /// returing void or action) llvm::ArrayRef getResultTypes() { - return {}; + return getFunctionType().getReturnTypes(); } /// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that @@ -678,6 +718,10 @@ def ActionOp : P4HIR_Op<"action", [ /// Ensures getType, getNumFuncArguments, and getNumFuncResults can be /// called safely. llvm::LogicalResult verifyType(); + + bool isDeclaration() { return isExternal(); } + + void createEntryBlock(); }]; let hasCustomAssemblyFormat = 1; @@ -706,7 +750,7 @@ def CallOp : P4HIR_Op<"call", }]; // TODO: Refine result types, refine parameter type - let results = (outs Optional:$result); + let results = (outs Optional:$result); let arguments = (ins OptionalAttr:$callee, Variadic:$operands); let skipDefaultBuilders = 1; @@ -718,7 +762,8 @@ def CallOp : P4HIR_Op<"call", CArg<"mlir::ValueRange", "{}">:$operands), [{ $_state.addOperands(operands); $_state.addAttribute("callee", callee); - $_state.addTypes(resType); + if (resType && !isa(resType)) + $_state.addTypes(resType); }]>, // Everything else that does not produce result OpBuilder<(ins "mlir::SymbolRefAttr":$callee, diff --git a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td index 33be597..228da41 100644 --- a/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td +++ b/include/p4mlir/Dialect/P4HIR/P4HIR_Types.td @@ -100,6 +100,15 @@ def DontcareType : P4HIR_Type<"Dontcare", "dontcare"> {} def ErrorType : P4HIR_Type<"Error", "error"> {} def UnknownType : P4HIR_Type<"Unknown", "unknown"> {} +def VoidType : P4HIR_Type<"Void", "void"> { + let summary = "void type"; + let description = [{ + Represents absense of result of actions and methods, or `void` type for functions. + }]; + let extraClassDeclaration = [{ + std::string getAlias() const { return "void"; }; + }]; +} //===----------------------------------------------------------------------===// // ReferenceType //===----------------------------------------------------------------------===// @@ -130,29 +139,40 @@ def ReferenceType : P4HIR_Type<"Reference", "ref"> { let skipDefaultBuilders = 1; } -//===----------------------------------------------------------------------===// -// ParameterType -//===----------------------------------------------------------------------===// - -def P4HIR_ActionType : P4HIR_Type<"Action", "action"> { - let summary = "P4 action type"; +def FuncType : P4HIR_Type<"Func", "func"> { + let summary = "P4 function-like type (actions, methods, functions)"; let description = [{ - The `!p4hir.action` is an action type. It is essentially a list of parameter - types. There is no return type for actions. - + The `!p4hir.func` is a function type. Example: ```mlir - !p4hir.action<()> - !p4hir.action<(!p4hir.bit<32>, !p4hir.int<42>)> + !p4hir.func<()> + !p4hir.func(!p4hir.bit<32>, !p4hir.int<42>)> ``` }]; - let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs); + let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, + "mlir::Type":$optionalReturnType); + + let builders = [ + // Construct with an actual return type or explicit !p4hir.void + TypeBuilderWithInferredContext<(ins + "llvm::ArrayRef":$inputs, "mlir::Type":$returnType), [{ + return $_get(returnType.getContext(), inputs, + mlir::isa(returnType) ? nullptr + : returnType); + }]>, + + // Construct without return type + TypeBuilder<(ins "llvm::ArrayRef":$inputs), [{ + return $_get($_ctxt, inputs, nullptr); + }]> + + ]; // Use a custom parser to handle the argument types in better way. let assemblyFormat = [{ - `<` custom($inputs) `>` + `<` custom($optionalReturnType, $inputs) `>` }]; let extraClassDeclaration = [{ @@ -162,9 +182,20 @@ def P4HIR_ActionType : P4HIR_Type<"Action", "action"> { /// Returns the number of arguments to the function. unsigned getNumInputs() const { return getInputs().size(); } + /// Returns the result type of the function as an actual return type or + /// explicit !p4hir.void + mlir::Type getReturnType() const; + + /// Returns the result type of the function as an ArrayRef, enabling better + /// integration with generic MLIR utilities. + llvm::ArrayRef getReturnTypes() const; + /// Returns a clone of this action type with the given argument /// and result types. Required for FunctionOp interface - ActionType clone(mlir::TypeRange inputs, mlir::TypeRange outputs) const; + FuncType clone(mlir::TypeRange inputs, mlir::TypeRange outputs) const; + + /// Returns whether the function returns void. + bool isVoid() const; }]; } @@ -174,6 +205,8 @@ def P4HIR_ActionType : P4HIR_Type<"Action", "action"> { def AnyP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, DontcareType, ErrorType, UnknownType]> {} +def CallResultP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType, VoidType]> {} def LoadableP4Type : AnyTypeOf<[BitsType, BooleanType, InfIntType]> {} + #endif // P4MLIR_DIALECT_P4HIR_P4HIR_TYPES_TD diff --git a/lib/Dialect/P4HIR/P4HIR_Ops.cpp b/lib/Dialect/P4HIR/P4HIR_Ops.cpp index c00f1df..9b9360f 100644 --- a/lib/Dialect/P4HIR/P4HIR_Ops.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Ops.cpp @@ -3,6 +3,7 @@ #include "llvm/Support/LogicalResult.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "p4mlir/Dialect/P4HIR/P4HIR_Attrs.h" #include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h" @@ -63,9 +64,9 @@ LogicalResult P4HIR::UnaryOp::verify() { // AllocaOp //===----------------------------------------------------------------------===// -void P4HIR::AllocaOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, - ::mlir::Type ref, ::mlir::Type objectType, ::llvm::StringRef name) { - odsState.addAttribute(getObjectTypeAttrName(odsState.name), ::mlir::TypeAttr::get(objectType)); +void P4HIR::AllocaOp::build(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, + mlir::Type ref, mlir::Type objectType, const llvm::Twine &name) { + odsState.addAttribute(getObjectTypeAttrName(odsState.name), mlir::TypeAttr::get(objectType)); odsState.addAttribute(getNameAttrName(odsState.name), odsBuilder.getStringAttr(name)); odsState.addTypes(ref); } @@ -264,6 +265,11 @@ void P4HIR::IfOp::print(OpAsmPrinter &p) { /// Default callback for IfOp builders. void P4HIR::buildTerminatedBody(OpBuilder &builder, Location loc) { + Block *block = builder.getBlock(); + + // Region is properly terminated: nothing to do. + if (block->mightHaveTerminator()) return; + // add p4hir.yield to the end of the block builder.create(loc); } @@ -306,9 +312,22 @@ void P4HIR::IfOp::build(OpBuilder &builder, OperationState &result, Value cond, } mlir::LogicalResult P4HIR::ReturnOp::verify() { - // TODO: Implement checks: - // - If we're inside action, then there should not be any operands - // - Otherwise, we're inside function, ensure operand type matches with result type + // Returns can be present in multiple different scopes, get the + // wrapping function and start from there. + auto *fnOp = getOperation()->getParentOp(); + while (!isa(fnOp)) fnOp = fnOp->getParentOp(); + + // ReturnOps currently only have a single optional operand. + if (getNumOperands() > 1) return emitOpError() << "expects at most 1 return operand"; + + // Ensure returned type matches the function signature. + auto expectedTy = cast(fnOp).getFunctionType().getReturnType(); + auto actualTy = + (getNumOperands() == 0 ? P4HIR::VoidType::get(getContext()) : getOperand(0).getType()); + if (actualTy != expectedTy) + return emitOpError() << "returns " << actualTy << " but enclosing function returns " + << expectedTy; + return success(); } @@ -319,45 +338,58 @@ mlir::LogicalResult P4HIR::ReturnOp::verify() { // Hook for OpTrait::FunctionLike, called after verifying that the 'type' // attribute is present. This can check for preconditions of the // getNumArguments hook not failing. -LogicalResult P4HIR::ActionOp::verifyType() { +LogicalResult P4HIR::FuncOp::verifyType() { auto type = getFunctionType(); - if (!isa(type)) + if (!isa(type)) return emitOpError("requires '" + getFunctionTypeAttrName().str() + - "' attribute of action type"); + "' attribute of function type"); + if (auto rt = type.getReturnTypes(); !rt.empty() && mlir::isa(rt.front())) + return emitOpError( + "The return type for a function returning void should " + "be empty instead of an explicit !p4hir.void"); return success(); } -LogicalResult P4HIR::ActionOp::verify() { +LogicalResult P4HIR::FuncOp::verify() { // TODO: Check that all reference-typed arguments have direction indicated + // TODO: Check that actions do have body return success(); } -mlir::Region *P4HIR::ActionOp::getCallableRegion() { return &getBody(); } +void P4HIR::FuncOp::build(OpBuilder &builder, OperationState &result, llvm::StringRef name, + P4HIR::FuncType type, ArrayRef attrs, + ArrayRef argAttrs) { + result.addRegion(); -void P4HIR::ActionOp::build(OpBuilder &builder, OperationState &result, llvm::StringRef name, - P4HIR::ActionType type, ArrayRef attrs, - ArrayRef argAttrs) { result.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); result.attributes.append(attrs.begin(), attrs.end()); + // We default to private visibility + result.addAttribute(SymbolTable::getVisibilityAttrName(), builder.getStringAttr("private")); - function_interface_impl::addArgAndResultAttrs( - builder, result, argAttrs, - /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(result.name), builder.getStringAttr("")); + function_interface_impl::addArgAndResultAttrs(builder, result, argAttrs, + /*resultAttrs=*/std::nullopt, + getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); +} - auto *region = result.addRegion(); - Block &first = region->emplaceBlock(); - for (auto argType : type.getInputs()) first.addArgument(argType, result.location); +void P4HIR::FuncOp::createEntryBlock() { + assert(empty() && "can only create entry block for empty function"); + Block &first = getFunctionBody().emplaceBlock(); + auto loc = getFunctionBody().getLoc(); + for (auto argType : getFunctionType().getInputs()) first.addArgument(argType, loc); } -void P4HIR::ActionOp::print(OpAsmPrinter &p) { +void P4HIR::FuncOp::print(OpAsmPrinter &p) { + if (getAction()) p << " action"; + // Print function name, signature, and control. p << ' '; p.printSymbolName(getSymName()); auto fnType = getFunctionType(); - llvm::SmallVector resultTypes; - function_interface_impl::printFunctionSignature(p, *this, fnType.getInputs(), false, {}); + function_interface_impl::printFunctionSignature(p, *this, fnType.getInputs(), false, + fnType.getReturnTypes()); if (mlir::ArrayAttr annotations = getAnnotationsAttr()) { p << ' '; @@ -367,24 +399,38 @@ void P4HIR::ActionOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionAttributes( p, *this, // These are all omitted since they are custom printed already. - {getFunctionTypeAttrName(), getArgAttrsAttrName()}); + {getFunctionTypeAttrName(), SymbolTable::getVisibilityAttrName(), getArgAttrsAttrName(), + getActionAttrName(), getResAttrsAttrName()}); // Print the body if this is not an external function. Region &body = getOperation()->getRegion(0); - p << ' '; - p.printRegion(body, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); + if (!body.empty()) { + p << ' '; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } } -ParseResult P4HIR::ActionOp::parse(OpAsmParser &parser, OperationState &state) { +ParseResult P4HIR::FuncOp::parse(OpAsmParser &parser, OperationState &state) { llvm::SMLoc loc = parser.getCurrentLocation(); auto &builder = parser.getBuilder(); + // Parse action marker + auto actionNameAttr = getActionAttrName(state.name); + bool isAction = false; + if (::mlir::succeeded(parser.parseOptionalKeyword(actionNameAttr.strref()))) { + isAction = true; + state.addAttribute(actionNameAttr, parser.getBuilder().getUnitAttr()); + } + // Parse the name as a symbol. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), state.attributes)) return failure(); + // We default to private visibility + state.addAttribute(SymbolTable::getVisibilityAttrName(), builder.getStringAttr("private")); + llvm::SmallVector arguments; llvm::SmallVector resultAttrs; llvm::SmallVector argTypes; @@ -395,13 +441,19 @@ ParseResult P4HIR::ActionOp::parse(OpAsmParser &parser, OperationState &state) { return failure(); // Actions have no results - if (!resultTypes.empty()) + if (isAction && !resultTypes.empty()) return parser.emitError(loc, "actions should not produce any results"); + else if (resultTypes.size() > 1) + return parser.emitError(loc, "functions only supports zero or one results"); - // Build the action function type. + // Build the function type. for (auto &arg : arguments) argTypes.push_back(arg.type); - if (auto fnType = P4HIR::ActionType::get(builder.getContext(), argTypes)) { + // Fetch return type or set it to void if empty/ommited. + mlir::Type returnType = + (resultTypes.empty() ? P4HIR::VoidType::get(builder.getContext()) : resultTypes.front()); + + if (auto fnType = P4HIR::FuncType::get(argTypes, returnType)) { state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(fnType)); } else return failure(); @@ -418,14 +470,19 @@ ParseResult P4HIR::ActionOp::parse(OpAsmParser &parser, OperationState &state) { assert(resultAttrs.size() == resultTypes.size()); function_interface_impl::addArgAndResultAttrs(builder, state, arguments, resultAttrs, getArgAttrsAttrName(state.name), - builder.getStringAttr("")); + getResAttrsAttrName(state.name)); // Parse the action body. auto *body = state.addRegion(); - ParseResult parseResult = parser.parseRegion(*body, arguments, /*enableNameShadowing=*/false); - if (failed(parseResult)) return failure(); - // Body was parsed, make sure its not empty. - if (body->empty()) return parser.emitError(loc, "expected non-empty action body"); + if (OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, arguments, /*enableNameShadowing=*/false); + parseResult.has_value()) { + if (failed(*parseResult)) return failure(); + // Function body was parsed, make sure its not empty. + if (body->empty()) return parser.emitError(loc, "expected non-empty function body"); + } else if (isAction) { + parser.emitError(loc, "action shall have a body"); + } return success(); } @@ -434,7 +491,7 @@ LogicalResult P4HIR::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable // Check that the callee attribute was specified. auto fnAttr = (*this)->getAttrOfType("callee"); if (!fnAttr) return emitOpError("requires a 'callee' symbol reference attribute"); - ActionOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); if (!fn) return emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; @@ -449,23 +506,26 @@ LogicalResult P4HIR::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable << fnType.getInput(i) << ", but provided " << getOperand(i).getType() << " for operand number " << i; - // FIXME: Generalize to functions - if (0 != getNumResults()) return emitOpError("incorrect number of results for callee"); + // Actions must not return any results + if (fn.getAction() && getNumResults() != 0) + return emitOpError("incorrect number of results for action call"); - /* - if (fnType.getNumResults( != getNumResults()) - return emitOpError("incorrect number of results for callee"); + // Void function must not return any results. + if (fnType.isVoid() && getNumResults() != 0) + return emitOpError("callee returns void but call has results"); - for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) - if (getResult(i).getType() != fnType.getResult(i)) { - auto diag = emitOpError("result type mismatch at index ") << i; - diag.attachNote() << " op result types: " << getResultTypes(); - diag.attachNote() << "function result types: " << fnType.getResults(); - return diag; - }*/ + // Non-void function calls must return exactly one result. + if (!fnType.isVoid() && getNumResults() != 1) + return emitOpError("incorrect number of results for callee"); + + // Parent function and return value types must match. + if (!fnType.isVoid() && getResultTypes().front() != fnType.getReturnType()) + return emitOpError("result type mismatch: expected ") + << fnType.getReturnType() << ", but provided " << getResult().getType(); return success(); } + void P4HIR::P4HIRDialect::initialize() { registerTypes(); registerAttributes(); diff --git a/lib/Dialect/P4HIR/P4HIR_Types.cpp b/lib/Dialect/P4HIR/P4HIR_Types.cpp index ea8b529..85b5b0e 100644 --- a/lib/Dialect/P4HIR/P4HIR_Types.cpp +++ b/lib/Dialect/P4HIR/P4HIR_Types.cpp @@ -6,9 +6,11 @@ #include "p4mlir/Dialect/P4HIR/P4HIR_Dialect.h" #include "p4mlir/Dialect/P4HIR/P4HIR_OpsEnums.h" -static mlir::ParseResult parseActionType(mlir::AsmParser &p, llvm::SmallVector ¶ms); +static mlir::ParseResult parseFuncType(mlir::AsmParser &p, mlir::Type &optionalResultType, + llvm::SmallVector ¶ms); -static void printActionType(mlir::AsmPrinter &p, mlir::ArrayRef params); +static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalResultType, + mlir::ArrayRef params); #define GET_TYPEDEF_CLASSES #include "p4mlir/Dialect/P4HIR/P4HIR_Types.cpp.inc" @@ -67,29 +69,74 @@ void P4HIRDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const }); } -ActionType ActionType::clone(TypeRange inputs, TypeRange results) const { - assert(results.size() == 0 && "expected exactly zero result type"); - return get(getContext(), llvm::to_vector(inputs)); +FuncType FuncType::clone(TypeRange inputs, TypeRange results) const { + assert(results.size() == 1 && "expected exactly one result type"); + return get(llvm::to_vector(inputs), results[0]); } -static mlir::ParseResult parseActionType(mlir::AsmParser &p, - llvm::SmallVector ¶ms) { - return p.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { +static mlir::ParseResult parseFuncType(mlir::AsmParser &p, mlir::Type &optionalReturnType, + llvm::SmallVector ¶ms) { + // Parse return type, if any + if (succeeded(p.parseOptionalLParen())) { + // If we have already a '(', the function has no return type + optionalReturnType = {}; + } else { mlir::Type type; if (p.parseType(type)) return mlir::failure(); - params.push_back(type); - return mlir::success(); - }); + if (mlir::isa(type)) + // An explicit !p4hir.void means also no return type. + optionalReturnType = {}; + else + // Otherwise use the actual type. + optionalReturnType = type; + if (p.parseLParen()) return mlir::failure(); + } + + // `(` `)` + if (succeeded(p.parseOptionalRParen())) return mlir::success(); + + if (p.parseCommaSeparatedList([&]() -> ParseResult { + mlir::Type type; + if (p.parseType(type)) return mlir::failure(); + params.push_back(type); + return mlir::success(); + })) + return mlir::failure(); return p.parseRParen(); } -static void printActionType(mlir::AsmPrinter &p, mlir::ArrayRef params) { +static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnType, + mlir::ArrayRef params) { + if (optionalReturnType) p << optionalReturnType << ' '; p << '('; llvm::interleaveComma(params, p, [&p](mlir::Type type) { p.printType(type); }); p << ')'; } +// Return the actual return type or an explicit !p4hir.void if the function does +// not return anything +mlir::Type FuncType::getReturnType() const { + if (isVoid()) return P4HIR::VoidType::get(getContext()); + return static_cast(getImpl())->optionalReturnType; +} + +/// Returns the result type of the function as an ArrayRef, enabling better +/// integration with generic MLIR utilities. +llvm::ArrayRef FuncType::getReturnTypes() const { + if (isVoid()) return {}; + return static_cast(getImpl())->optionalReturnType; +} + +// Whether the function returns void +bool FuncType::isVoid() const { + auto rt = static_cast(getImpl())->optionalReturnType; + assert(!rt || !mlir::isa(rt) && + "The return type for a function returning void should be empty " + "instead of a real !p4hir.void"); + return !rt; +} + void P4HIRDialect::registerTypes() { addTypes< #define GET_TYPEDEF_LIST diff --git a/test/Dialect/P4HIR/action.mlir b/test/Dialect/P4HIR/action.mlir index 04df9ba..9ca285b 100644 --- a/test/Dialect/P4HIR/action.mlir +++ b/test/Dialect/P4HIR/action.mlir @@ -3,11 +3,11 @@ !bit32 = !p4hir.bit<32> // CHECK: module -// CHECK-LABEL: p4hir.action @foo(%arg0: !p4hir.ref> {p4hir.dir = #p4hir}, %arg1: !p4hir.bit<32> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}, %arg3: !p4hir.int<42>) { -p4hir.action @foo(%arg0 : !p4hir.ref {p4hir.dir = #p4hir}, - %arg1 : !bit32 {p4hir.dir = #p4hir}, - %arg2 : !p4hir.ref {p4hir.dir = #p4hir}, - %arg3 : !p4hir.int<42>) { +// CHECK-LABEL: p4hir.func action @foo(%arg0: !p4hir.ref> {p4hir.dir = #p4hir}, %arg1: !p4hir.bit<32> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}, %arg3: !p4hir.int<42>) { +p4hir.func action @foo(%arg0 : !p4hir.ref {p4hir.dir = #p4hir}, + %arg1 : !bit32 {p4hir.dir = #p4hir}, + %arg2 : !p4hir.ref {p4hir.dir = #p4hir}, + %arg3 : !p4hir.int<42>) { %0 = p4hir.alloca !bit32 ["tmp"] : !p4hir.ref %1 = p4hir.load %arg0 : !p4hir.ref, !bit32 diff --git a/test/Dialect/P4HIR/call.mlir b/test/Dialect/P4HIR/call.mlir index 89659c2..077802c 100644 --- a/test/Dialect/P4HIR/call.mlir +++ b/test/Dialect/P4HIR/call.mlir @@ -3,11 +3,11 @@ !bit32 = !p4hir.bit<32> // CHECK: module -// CHECK-LABEL: p4hir.action @foo(%arg0: !p4hir.ref> {p4hir.dir = #p4hir}, %arg1: !p4hir.bit<32> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}, %arg3: !p4hir.int<42>) { -p4hir.action @foo(%arg0 : !p4hir.ref {p4hir.dir = #p4hir}, - %arg1 : !bit32 {p4hir.dir = #p4hir}, - %arg2 : !p4hir.ref {p4hir.dir = #p4hir}, - %arg3 : !p4hir.int<42>) { +// CHECK-LABEL: p4hir.func action @foo(%arg0: !p4hir.ref> {p4hir.dir = #p4hir}, %arg1: !p4hir.bit<32> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}, %arg3: !p4hir.int<42>) { +p4hir.func action @foo(%arg0 : !p4hir.ref {p4hir.dir = #p4hir}, + %arg1 : !bit32 {p4hir.dir = #p4hir}, + %arg2 : !p4hir.ref {p4hir.dir = #p4hir}, + %arg3 : !p4hir.int<42>) { %0 = p4hir.alloca !bit32 ["tmp"] : !p4hir.ref %1 = p4hir.load %arg0 : !p4hir.ref, !bit32 @@ -17,7 +17,7 @@ p4hir.action @foo(%arg0 : !p4hir.ref {p4hir.dir = #p4hir}, p4hir.return } -p4hir.action @bar() { +p4hir.func action @bar() { %0 = p4hir.alloca !bit32 ["tmp"] : !p4hir.ref %1 = p4hir.load %0 : !p4hir.ref, !bit32 %3 = p4hir.const #p4hir.int<7> : !p4hir.int<42> diff --git a/test/Dialect/P4HIR/types.mlir b/test/Dialect/P4HIR/types.mlir index 9e420f0..b71d2ed 100644 --- a/test/Dialect/P4HIR/types.mlir +++ b/test/Dialect/P4HIR/types.mlir @@ -5,8 +5,8 @@ !dontcare = !p4hir.dontcare !ref = !p4hir.ref> -!action_noparams = !p4hir.action<()> -!action_params = !p4hir.action<(!p4hir.int<42>, !ref, !p4hir.int<42>, !p4hir.bool)> +!action_noparams = !p4hir.func<()> +!action_params = !p4hir.func<(!p4hir.int<42>, !ref, !p4hir.int<42>, !p4hir.bool)> // No need to check stuff. If it parses, it's fine. // CHECK: module diff --git a/test/Translate/Ops/action.p4 b/test/Translate/Ops/action.p4 index 4f1f33c..4e3b2fc 100644 --- a/test/Translate/Ops/action.p4 +++ b/test/Translate/Ops/action.p4 @@ -1,6 +1,6 @@ // RUN: p4mlir-translate --typeinference-only %s | FileCheck %s -// CHECK-LABEL: p4hir.action @foo(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir}, %arg1: !p4hir.ref> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}, %arg3: !p4hir.bit<16> {p4hir.dir = #p4hir}) +// CHECK-LABEL: p4hir.func action @foo(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir}, %arg1: !p4hir.ref> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}, %arg3: !p4hir.bit<16> {p4hir.dir = #p4hir}) // CHECK: p4hir.return action foo(in bit<16> arg1, inout int<10> arg2, out bit<16> arg3, bit<16> arg4) { bit<16> x = arg1; @@ -11,7 +11,7 @@ action foo(in bit<16> arg1, inout int<10> arg2, out bit<16> arg3, bit<16> arg4) return; } -// CHECK-LABEL: p4hir.action @bar(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir}) { +// CHECK-LABEL: p4hir.func action @bar(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir}) { // CHECK: p4hir.return action bar(bit<16> arg1) { } diff --git a/test/Translate/Ops/assign.p4 b/test/Translate/Ops/assign.p4 index 0301463..35c31a6 100644 --- a/test/Translate/Ops/assign.p4 +++ b/test/Translate/Ops/assign.p4 @@ -12,7 +12,7 @@ action assign() { res = lhs + rhs; } -// CHECK-LABEL: p4hir.action @assign() +// CHECK-LABEL: p4hir.func action @assign() // CHECK: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bit<10> ["res"] : !p4hir.ref> // CHECK: %[[VAL_1:.*]] = p4hir.const #p4hir.int<1> : !p4hir.bit<10> // CHECK: %[[VAL_2:.*]] = p4hir.cast(%[[VAL_1]] : !p4hir.bit<10>) : !p4hir.bit<10> diff --git a/test/Translate/Ops/binop.p4 b/test/Translate/Ops/binop.p4 index 7403e13..7848fcc 100644 --- a/test/Translate/Ops/binop.p4 +++ b/test/Translate/Ops/binop.p4 @@ -53,7 +53,7 @@ action int_binops() { int<10> r13 = lhs ^ rhs; } -// CHECK-LABEL: p4hir.action @bit_binops() +// CHECK-LABEL: p4hir.func action @bit_binops() // CHECK: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bit<10> ["res"] : !p4hir.ref> // CHECK: %[[VAL_1:.*]] = p4hir.const #p4hir.int<1> : !p4hir.bit<10> // CHECK: %[[VAL_2:.*]] = p4hir.cast(%[[VAL_1]] : !p4hir.bit<10>) : !p4hir.bit<10> @@ -139,7 +139,7 @@ action int_binops() { // CHECK: %[[VAL_67:.*]] = p4hir.binop(xor, %[[VAL_65]], %[[VAL_66]]) : !p4hir.bit<10> // CHECK: %[[VAL_68:.*]] = p4hir.alloca !p4hir.bit<10> ["r14", init] : !p4hir.ref> // CHECK: p4hir.store %[[VAL_67]], %[[VAL_68]] : !p4hir.bit<10>, !p4hir.ref> -// CHECK-LABEL: p4hir.action @int_binops() +// CHECK-LABEL: p4hir.func action @int_binops() // CHECK: %[[VAL_69:.*]] = p4hir.alloca !p4hir.int<10> ["res"] : !p4hir.ref> // CHECK: %[[VAL_70:.*]] = p4hir.const #p4hir.int<1> : !p4hir.int<10> // CHECK: %[[VAL_71:.*]] = p4hir.cast(%[[VAL_70]] : !p4hir.int<10>) : !p4hir.int<10> diff --git a/test/Translate/Ops/calls.p4 b/test/Translate/Ops/calls.p4 index 25ddcea..350817b 100644 --- a/test/Translate/Ops/calls.p4 +++ b/test/Translate/Ops/calls.p4 @@ -36,13 +36,13 @@ action bazz(in int<16> arg1) { // CHECK: p4hir.call @bar() : () -> () bar(); // CHECK: p4hir.scope - // CHECK: %[[VAR_A:.*]] = p4hir.alloca !p4hir.int<16> ["a"] : !p4hir.ref> + // CHECK: %[[VAR_A:.*]] = p4hir.alloca !p4hir.int<16> ["a_out"] : !p4hir.ref> // CHECK: p4hir.call @quuz(%[[VAR_A]]) : (!p4hir.ref>) -> () // CHECK: p4hir.load %[[VAR_A]] : !p4hir.ref>, !p4hir.int<16> int<16> val; quuz(val); // CHECK: p4hir.scope - // CHECK: %[[VAR_X:.*]] = p4hir.alloca !p4hir.int<16> ["x", init] : !p4hir.ref> + // CHECK: %[[VAR_X:.*]] = p4hir.alloca !p4hir.int<16> ["x_inout", init] : !p4hir.ref> // CHECK: %[[VAL_X:.*]] = p4hir.load %[[VAL:.*]] : !p4hir.ref>, !p4hir.int<16> // CHECK: p4hir.store %[[VAL_X]], %[[VAR_X]] : !p4hir.int<16>, !p4hir.ref> // CHECK: p4hir.call @baz(%[[VAR_X]]) diff --git a/test/Translate/Ops/cmp.p4 b/test/Translate/Ops/cmp.p4 index f738234..613e69b 100644 --- a/test/Translate/Ops/cmp.p4 +++ b/test/Translate/Ops/cmp.p4 @@ -2,7 +2,7 @@ // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// CHECK-LABEL: p4hir.action @cmp() +// CHECK-LABEL: p4hir.func action @cmp() // CHECK: %[[VAL_0:.*]] = p4hir.alloca !p4hir.bool ["res"] : !p4hir.ref // CHECK: %[[VAL_1:.*]] = p4hir.const #p4hir.int<1> : !p4hir.bit<10> // CHECK: %[[VAL_2:.*]] = p4hir.cast(%[[VAL_1]] : !p4hir.bit<10>) : !p4hir.bit<10> diff --git a/test/Translate/Ops/function.p4 b/test/Translate/Ops/function.p4 new file mode 100644 index 0000000..71c17b2 --- /dev/null +++ b/test/Translate/Ops/function.p4 @@ -0,0 +1,64 @@ +// RUN: p4mlir-translate --typeinference-only %s | FileCheck %s + +// CHECK-LABEL: p4hir.func @max(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir}, %arg1: !p4hir.bit<16> {p4hir.dir = #p4hir}) -> !p4hir.bit<16> +// CHECK: %[[CMP:.*]] = p4hir.cmp(gt, %arg0, %arg1) : !p4hir.bit<16>, !p4hir.bool +// CHECK: p4hir.if %[[CMP]] { +// CHECK: p4hir.return %arg0 : !p4hir.bit<16> +// CHECK: } +// CHECK: p4hir.return %arg1 : !p4hir.bit<16> +// CHECK: } + +bit<16> max(in bit<16> left, in bit<16> right) { + if (left > right) + return left; + return right; +} + +// CHECK-LABEL: p4hir.func action @bar(%arg0: !p4hir.bit<16> {p4hir.dir = #p4hir}, %arg1: !p4hir.bit<16> {p4hir.dir = #p4hir}, %arg2: !p4hir.ref> {p4hir.dir = #p4hir}) { +// CHECK: %[[CALL:.*]] = p4hir.call @max(%arg0, %arg1) : (!p4hir.bit<16>, !p4hir.bit<16>) -> !p4hir.bit<16> +// CHECK: p4hir.store %[[CALL]], %arg2 : !p4hir.bit<16>, !p4hir.ref> +// CHECK: p4hir.return + +action bar(in bit<16> arg1, in bit<16> arg2, out bit<16> res) { + res = max(arg1, arg2); +} + +// Example from P4 language spec (6.8. Calling convention: call by copy in/copy out) +// The function call is equivalent to the following sequence of statements: +// bit tmp1 = a; // evaluate a; save result +// bit tmp2 = g(a); // evaluate g(a); save result; modifies a +// f(tmp1, tmp2); // evaluate f; modifies tmp1 +// a = tmp1; // copy inout result back into a +// However, we limit the scope of temporaries via structured control flow + +// CHECK: p4hir.func @f(!p4hir.ref> {p4hir.dir = #p4hir}, !p4hir.bit<1> {p4hir.dir = #p4hir}) +extern void f(inout bit x, in bit y); +// CHECK: p4hir.func @g(!p4hir.ref> {p4hir.dir = #p4hir}) -> !p4hir.bit<1> +extern bit g(inout bit z); + +action test_param() { + bit a; + f(a, g(a)); +} + +// CHECK-LABEL: p4hir.func action @test_param() { +// CHECK: %[[A:.*]] = p4hir.alloca !p4hir.bit<1> ["a"] : !p4hir.ref> +// CHECK: p4hir.scope { +// CHECK: %[[X_INOUT:.*]] = p4hir.alloca !p4hir.bit<1> ["x_inout", init] : !p4hir.ref> +// CHECK: %[[A_VAL:.*]] = p4hir.load %[[A]] : !p4hir.ref>, !p4hir.bit<1> +// CHECK: p4hir.store %[[A_VAL]], %[[X_INOUT]] : !p4hir.bit<1>, !p4hir.ref> +// CHECK: %[[G_VAL:.*]] = p4hir.scope { +// CHECK: %[[Z_INOUT:.*]] = p4hir.alloca !p4hir.bit<1> ["z_inout", init] : !p4hir.ref> +// CHECK: %[[A_VAL2:.*]] = p4hir.load %[[A]] : !p4hir.ref>, !p4hir.bit<1> +// CHECK: p4hir.store %[[A_VAL2]], %[[Z_INOUT]] : !p4hir.bit<1>, !p4hir.ref> +// CHECK: %[[G_RES:.*]] = p4hir.call @g(%[[Z_INOUT]]) : (!p4hir.ref>) -> !p4hir.bit<1> +// CHECK: %[[A_OUT_VAL:.*]] = p4hir.load %[[Z_INOUT]] : !p4hir.ref>, !p4hir.bit<1> +// CHECK: p4hir.store %[[A_OUT_VAL]], %[[A]] : !p4hir.bit<1>, !p4hir.ref> +// CHECK: p4hir.yield %[[G_RES]] : !p4hir.bit<1> +// CHECK: } : !p4hir.bit<1> +// CHECK: p4hir.call @f(%[[X_INOUT]], %[[G_VAL]]) : (!p4hir.ref>, !p4hir.bit<1>) -> () +// CHECK: %[[A_OUT_VAL2:.*]] = p4hir.load %[[X_INOUT]] : !p4hir.ref>, !p4hir.bit<1> +// CHECK: p4hir.store %[[A_OUT_VAL2]], %[[A]] : !p4hir.bit<1>, !p4hir.ref> +// CHECK: } +// CHECK: p4hir.return +// CHECK: } diff --git a/test/Translate/Ops/scope.p4 b/test/Translate/Ops/scope.p4 index 8e58d75..e952cd6 100644 --- a/test/Translate/Ops/scope.p4 +++ b/test/Translate/Ops/scope.p4 @@ -1,6 +1,6 @@ // RUN: p4mlir-translate --typeinference-only %s | FileCheck %s -// CHECK-LABEL: p4hir.action @scope() +// CHECK-LABEL: p4hir.func action @scope() action scope() { bool res; // Outer alloca diff --git a/test/Translate/Ops/unop.p4 b/test/Translate/Ops/unop.p4 index 5e9860b..c414934 100644 --- a/test/Translate/Ops/unop.p4 +++ b/test/Translate/Ops/unop.p4 @@ -2,7 +2,7 @@ // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// CHECK-LABEL: p4hir.action @foo() +// CHECK-LABEL: p4hir.func action @foo() // CHECK: %[[VAL_0:.*]] = p4hir.const #p4hir.bool : !p4hir.bool // CHECK: %[[VAL_1:.*]] = p4hir.alloca !p4hir.bool ["b0", init] : !p4hir.ref // CHECK: p4hir.store %[[VAL_0]], %[[VAL_1]] : !p4hir.bool, !p4hir.ref diff --git a/test/Translate/Ops/variables.p4 b/test/Translate/Ops/variables.p4 index b60672e..19c6419 100644 --- a/test/Translate/Ops/variables.p4 +++ b/test/Translate/Ops/variables.p4 @@ -17,7 +17,7 @@ action foo() { bit<8> b10 = (bit<8>)b8; } -// CHECK-LABEL: p4hir.action @foo() +// CHECK-LABEL: p4hir.func action @foo() // CHECK: %[[VAL_0:.*]] = p4hir.const #p4hir.int<255> : !p4hir.bit<32> // CHECK: %[[VAL_1:.*]] = p4hir.alloca !p4hir.bit<32> ["b0", init] : !p4hir.ref> // CHECK: p4hir.store %[[VAL_0]], %[[VAL_1]] : !p4hir.bit<32>, !p4hir.ref> diff --git a/tools/p4mlir-translate/translate.cpp b/tools/p4mlir-translate/translate.cpp index b3e10f3..60b9fdc 100644 --- a/tools/p4mlir-translate/translate.cpp +++ b/tools/p4mlir-translate/translate.cpp @@ -3,9 +3,6 @@ #include #include -#include "ir/ir-generated.h" -#include "p4mlir/third_party/llvm-project/mlir/include/mlir/IR/Value.h" - #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wcovered-switch-default" #include "frontends/common/resolveReferences/resolveReferences.h" @@ -132,6 +129,8 @@ class P4TypeConverter : public P4::Inspector { bool preorder(const P4::IR::Type_Name *name) override; bool preorder(const P4::IR::Type_Action *act) override; + bool preorder(const P4::IR::Type_Method *m) override; + bool preorder(const P4::IR::Type_Void *v) override; mlir::Type getType() const { return type; } bool setType(const P4::IR::Type *type, mlir::Type mlirType); @@ -153,7 +152,8 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { // llvm::DenseMap p4Constants; llvm::DenseMap p4Constants; llvm::DenseMap p4Values; - using P4Symbol = std::variant; + using P4Symbol = + std::variant; // TODO: Implement better scoped symbol table llvm::DenseMap p4Symbols; @@ -253,10 +253,12 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { } mlir::Value setValue(const P4::IR::Node *node, mlir::Value value) { + if (!value) return value; + if (LOGGING(4)) { std::string s; llvm::raw_string_ostream os(s); - value.print(os); + value.print(os, mlir::OpPrintingFlags().assumeVerified()); LOG4("Converted " << dbp(node) << " -> \"" << s << "\""); } @@ -281,6 +283,8 @@ class P4HIRConverter : public P4::Inspector, public P4::ResolutionContext { bool preorder(const P4::IR::P4Program *) override { return true; } bool preorder(const P4::IR::P4Action *a) override; + bool preorder(const P4::IR::Function *f) override; + bool preorder(const P4::IR::Method *m) override; bool preorder(const P4::IR::BlockStatement *block) override { // If this is a top-level block where scope is implied (e.g. function, // action, certain statements) do not create explicit scope. @@ -420,7 +424,35 @@ bool P4TypeConverter::preorder(const P4::IR::Type_Action *type) { argTypes.push_back(p->hasOut() ? P4HIR::ReferenceType::get(type) : type); } - auto mlirType = P4HIR::ActionType::get(converter.context(), argTypes); + auto mlirType = P4HIR::FuncType::get(converter.context(), argTypes); + return setType(type, mlirType); +} + +bool P4TypeConverter::preorder(const P4::IR::Type_Method *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + llvm::SmallVector argTypes; + + CHECK_NULL(type->parameters); + CHECK_NULL(type->returnType); + + mlir::Type resultType = convert(type->returnType); + + for (const auto *p : type->parameters->parameters) { + mlir::Type type = convert(p->type); + argTypes.push_back(p->hasOut() ? P4HIR::ReferenceType::get(type) : type); + } + + auto mlirType = P4HIR::FuncType::get(argTypes, resultType); + return setType(type, mlirType); +} + +bool P4TypeConverter::preorder(const P4::IR::Type_Void *type) { + if ((this->type = converter.findType(type))) return false; + + ConversionTracer trace("TypeConverting ", type); + auto mlirType = P4HIR::VoidType::get(converter.context()); return setType(type, mlirType); } @@ -691,19 +723,11 @@ bool P4HIRConverter::preorder(const P4::IR::IfStatement *ifs) { return false; } -bool P4HIRConverter::preorder(const P4::IR::P4Action *act) { - ConversionTracer trace("Converting ", act); - - // TODO: Actions might reference some control locals, we need to make - // them visible somehow (e.g. via additional arguments) - - // FIXME: Get rid of typeMap: ensure action knows its type - auto actType = mlir::cast(getOrCreateType(typeMap->getType(act, true))); - const auto ¶ms = act->getParameters()->parameters; - +static llvm::SmallVector convertParamDirections( + const P4::IR::ParameterList *params, mlir::MLIRContext *ctxt) { // Create attributes for directions llvm::SmallVector argAttrs; - for (const auto *p : params) { + for (const auto *p : params->parameters) { P4HIR::ParamDirection dir; switch (p->direction) { case P4::IR::Direction::None: @@ -721,16 +745,90 @@ bool P4HIRConverter::preorder(const P4::IR::P4Action *act) { }; mlir::NamedAttribute dirAttr( - mlir::StringAttr::get(context(), P4HIR::ActionOp::getDirectionAttrName()), - P4HIR::ParamDirectionAttr::get(context(), dir)); + mlir::StringAttr::get(ctxt, P4HIR::FuncOp::getDirectionAttrName()), + P4HIR::ParamDirectionAttr::get(ctxt, dir)); + + argAttrs.emplace_back(mlir::DictionaryAttr::get(ctxt, dirAttr)); + } + + return argAttrs; +} + +bool P4HIRConverter::preorder(const P4::IR::Function *f) { + ConversionTracer trace("Converting ", f); + + auto funcType = mlir::cast(getOrCreateType(f->type)); + const auto ¶ms = f->getParameters()->parameters; + + auto argAttrs = convertParamDirections(f->getParameters(), context()); + assert(funcType.getNumInputs() == argAttrs.size() && "invalid parameter conversion"); + + auto func = builder.create(getLoc(builder, f), f->name.string_view(), funcType, + llvm::ArrayRef(), argAttrs); + func.createEntryBlock(); + + // Iterate over parameters again binding parameter values to arguments of first BB + auto &body = func.getBody(); + + assert(body.getNumArguments() == params.size() && "invalid parameter conversion"); + for (auto [param, bodyArg] : llvm::zip(params, body.getArguments())) setValue(param, bodyArg); - argAttrs.emplace_back(mlir::DictionaryAttr::get(context(), dirAttr)); + // We cannot simply visit each node of the top-level block as + // ResolutionContext would not be able to resolve declarations there + // (sic!) + { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&body.front()); + visit(f->body); + + // Check if body's last block is not terminated. + mlir::Block &b = body.back(); + if (!b.mightHaveTerminator()) { + builder.setInsertionPointToEnd(&b); + builder.create(getEndLoc(builder, f)); + } } + + auto [it, inserted] = p4Symbols.try_emplace(f, mlir::SymbolRefAttr::get(func)); + BUG_CHECK(inserted, "duplicate translation of %1%", f); + + return false; +} + +// We treat method as an external function (w/o body) +bool P4HIRConverter::preorder(const P4::IR::Method *m) { + ConversionTracer trace("Converting ", m); + + auto funcType = mlir::cast(getOrCreateType(m->type)); + + auto argAttrs = convertParamDirections(m->getParameters(), context()); + assert(funcType.getNumInputs() == argAttrs.size() && "invalid parameter conversion"); + + auto func = builder.create(getLoc(builder, m), m->name.string_view(), funcType, + llvm::ArrayRef(), argAttrs); + + auto [it, inserted] = p4Symbols.try_emplace(m, mlir::SymbolRefAttr::get(func)); + BUG_CHECK(inserted, "duplicate translation of %1%", m); + + return false; +} + +bool P4HIRConverter::preorder(const P4::IR::P4Action *act) { + ConversionTracer trace("Converting ", act); + + // TODO: Actions might reference some control locals, we need to make + // them visible somehow (e.g. via additional arguments) + + // FIXME: Get rid of typeMap: ensure action knows its type + auto actType = mlir::cast(getOrCreateType(typeMap->getType(act, true))); + const auto ¶ms = act->getParameters()->parameters; + + auto argAttrs = convertParamDirections(act->getParameters(), context()); assert(actType.getNumInputs() == argAttrs.size() && "invalid parameter conversion"); auto action = - builder.create(getLoc(builder, act), act->name.string_view(), actType, - llvm::ArrayRef(), argAttrs); + P4HIR::FuncOp::buildAction(builder, getLoc(builder, act), act->name.string_view(), actType, + llvm::ArrayRef(), argAttrs); // Iterate over parameters again binding parameter values to arguments of first BB auto &body = action.getBody(); @@ -785,7 +883,9 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { // model copy-in/out semantics. To limit the lifetime of those temporaries, do // everything in the dedicated block scope. If there are no out parameters, // then emit everything direct. - auto convertCall = [&](mlir::OpBuilder &b, mlir::Location loc) { + bool emitScope = + std::any_of(params.begin(), params.end(), [](const auto *p) { return p->hasOut(); }); + auto convertCall = [&](mlir::OpBuilder &b, mlir::Type &resultType, mlir::Location loc) { llvm::SmallVector operands; for (auto [idx, arg] : llvm::enumerate(*mce->arguments)) { ConversionTracer trace("Converting ", arg); @@ -803,8 +903,10 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { auto ref = resolveReference(arg->expression); auto type = mlir::cast(ref.getType()); - auto copyIn = b.create(loc, type, type.getObjectType(), - params[idx]->name.string_view()); + auto copyIn = b.create( + loc, type, type.getObjectType(), + llvm::Twine(params[idx]->name.string_view()) + + (dir == P4::IR::Direction::InOut ? "_inout" : "_out")); if (dir == P4::IR::Direction::InOut) { copyIn.setInit(true); @@ -817,13 +919,28 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { operands.push_back(argVal); } + mlir::Value callResult; if (const auto *actCall = instance->to()) { auto actSym = p4Symbols.lookup(actCall->action); BUG_CHECK(actSym, "expected reference action to be converted: %1%", actCall->action); b.create(loc, actSym, operands); + } else if (const auto *fCall = instance->to()) { + auto fSym = p4Symbols.lookup(fCall->function); + auto callResultType = getOrCreateType(instance->originalMethodType->returnType); + + BUG_CHECK(fSym, "expected reference function to be converted: %1%", fCall->function); + + callResult = b.create(loc, fSym, callResultType, operands).getResult(); + } else if (const auto *fCall = instance->to()) { + auto fSym = p4Symbols.lookup(fCall->method); + auto callResultType = getOrCreateType(instance->originalMethodType->returnType); + + BUG_CHECK(fSym, "expected reference function to be converted: %1%", fCall->method); + + callResult = b.create(loc, fSym, callResultType, operands).getResult(); } else { - BUG("unsupported call type: %1%", instance); + BUG("unsupported call type: %1%", mce); } for (auto [idx, arg] : llvm::enumerate(*mce->arguments)) { @@ -836,15 +953,25 @@ bool P4HIRConverter::preorder(const P4::IR::MethodCallExpression *mce) { getEndLoc(builder, mce), builder.create(getEndLoc(builder, mce), copyOut), dest); } + + // If we are inside the scope, then build the yield of the call result + if (emitScope) { + if (callResult) { + resultType = callResult.getType(); + b.create(getEndLoc(b, mce), callResult); + } else + b.create(getEndLoc(b, mce)); + } else { + setValue(mce, callResult); + } }; - if (std::any_of(params.begin(), params.end(), [](const auto *p) { return p->hasOut(); })) { - mlir::OpBuilder::InsertionGuard guard(builder); + if (emitScope) { auto scope = builder.create(getLoc(builder, mce), convertCall); - builder.setInsertionPointToEnd(&scope.getScopeRegion().back()); - builder.create(getEndLoc(builder, mce)); + setValue(mce, scope.getResults()); } else { - convertCall(builder, getLoc(builder, mce)); + mlir::Type resultType; + convertCall(builder, resultType, getLoc(builder, mce)); } return false;