diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 8c61b8be..364655f8 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -265,16 +265,31 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> { ``` %c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> %d:2 = xten_nn.kernel "frob" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32} {attr = 4 : i32} -> tensor<2xi64> ``` }]; let arguments = (ins Variadic:$arguments, - StrAttr:$name + StrAttr:$name, + OptionalAttr:$instantiation_args, + OptionalAttr:$instantiation_arg_names ); let results = (outs Variadic:$results); + let builders = [ + OpBuilder<(ins + "::mlir::TypeRange":$results, + "::mlir::ValueRange":$arguments, + "::llvm::StringRef":$name), [{ + build($_builder, $_state, results, arguments, name, + ::mlir::ArrayAttr(), ::mlir::ArrayAttr()); + }] + > + ]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index cc4c67af..2af2a988 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -10,23 +10,31 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/SmallVector.h" +#include "xten/Dialect/XTenNN/IR/XTenNNOps.h" + #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "xten/Dialect/XTenNN/IR/XTenNN.h" #include "xten/Dialect/XTenNN/IR/XTenNNBase.h" -#include "xten/Dialect/XTenNN/IR/XTenNNOps.h" #include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h" + #include using namespace mlir; @@ -215,15 +223,87 @@ static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types, } // Parse -// $name custom(type($arguments), $arguments) attr-dict +// [((name = )?value, )*((name = )?value)] +static ParseResult parseKernelInstantiationArgs(OpAsmParser &p, + SmallVector &values, + SmallVector &names) { + if (failed(p.parseLSquare())) + return failure(); + + if (failed(p.parseCommaSeparatedList([&p, &names, &values]() { + std::string name; + if (succeeded(p.parseOptionalString(&name))) { + names.push_back(StringAttr::get(p.getContext(), name)); + if (failed(p.parseEqual())) + return failure(); + } + Attribute attr; + auto res = p.parseOptionalAttribute(attr); + if (res.has_value() && succeeded(*res)) { + values.push_back(attr); + } + if (res.has_value() && failed(*res)) + return failure(); + + return success(); + }))) { + return failure(); + } + + if (failed(p.parseRSquare())) + return failure(); + + return success(); +} + +// Print +// instantiation_args [((name = )?value, )*((name = )?value)] +static void +printKernelInstantiationArgs(OpAsmPrinter &p, + ArrayRef instantiationArgs, + ArrayRef instantiationArgNames) { + if (!instantiationArgs.empty()) { + p << "instantiation_args ["; + auto zipped = llvm::zip_longest(instantiationArgNames, instantiationArgs); + for (auto iter = zipped.begin(); iter != zipped.end(); ++iter) { + if (iter != zipped.begin()) + p << ", "; + auto [name, value] = *iter; + if (name) + p << *name << " = "; + if (value) + p.printAttribute(*value); + } + p << ']'; + } +} + +// Parse +// $name custom(type($arguments), $arguments) +// (instantiation_args custom)? attr-dict // `->` type($results) ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { StringAttr name; if (p.parseAttribute(name, "name", result.attributes)) return failure(); - if (parseKernelArgumentList(p, result.operands) || - p.parseOptionalAttrDict(result.attributes)) + if (parseKernelArgumentList(p, result.operands)) + return failure(); + + if (succeeded(p.parseOptionalKeyword("instantiation_args"))) { + SmallVector values; + SmallVector names; + if (failed(parseKernelInstantiationArgs(p, values, names))) + return failure(); + result.addAttribute("instantiation_args", + ArrayAttr::get(p.getContext(), values)); + if (!names.empty()) { + result.addAttribute("instantiation_arg_names", + ArrayAttr::get(p.getContext(), names)); + } + } + + if (p.parseOptionalAttrDict(result.attributes)) return failure(); // If the op has no results, the `-> type($results)` is absent. @@ -236,8 +316,9 @@ ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) { return success(); } -// Parse -// $name custom(type($arguments), $arguments) attr-dict +// Print +// $name custom(type($arguments), $arguments) +// (instantiation_args custom)? attr-dict // `->` type($results) void KernelOp::print(OpAsmPrinter &p) { p << ' '; @@ -245,9 +326,26 @@ void KernelOp::print(OpAsmPrinter &p) { p << ' '; printKernelArgumentList(p, getOperandTypes(), getOperands()); p << ' '; - SmallVector elidedAttrs = {"name"}; + auto instantiationArgs = getInstantiationArgs(); + auto instantiationArgNames = getInstantiationArgNames(); + if (instantiationArgs != std::nullopt) { + printKernelInstantiationArgs(p, instantiationArgs->getValue(), + (instantiationArgNames == std::nullopt) + ? ArrayRef() + : instantiationArgNames->getValue()); + p << ' '; + } + + SmallVector elidedAttrs = {"name", "instantiation_args", + "instantiation_arg_names"}; p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); - if (getOperation()->getAttrs().size() > elidedAttrs.size()) + if (llvm::any_of( + getOperation()->getAttrs(), [&elidedAttrs](NamedAttribute a) { + auto name = a.getName(); + return llvm::any_of(elidedAttrs, [&name](StringRef elidedName) { + return name == elidedName; + }); + })) p << ' '; if (getNumResults()) { p << "-> "; @@ -255,6 +353,23 @@ void KernelOp::print(OpAsmPrinter &p) { } } +LogicalResult KernelOp::verify() { + if (getInstantiationArgNames().has_value()) { + if (!getInstantiationArgs().has_value()) { + return emitOpError( + "cannot have instantiation arg names without instantiation args"); + } + if (!(getInstantiationArgNamesAttr().empty() || + getInstantiationArgNamesAttr().size() == + getInstantiationArgsAttr().size())) { + return emitOpError("instantiation arg names must be either empty or as " + "long as instantiation args"); + } + } + + return success(); +} + #define GET_OP_CLASSES #include "xten/Dialect/XTenNN/IR/XTenNNOps.cpp.inc" @@ -266,9 +381,7 @@ ParseResult SubgraphOp::parse(OpAsmParser &p, OperationState &result) { return parseEnclaveOp(p, result); } -void SubgraphOp::print(OpAsmPrinter &p) { - printEnclaveOp(p, *this); -} +void SubgraphOp::print(OpAsmPrinter &p) { printEnclaveOp(p, *this); } LogicalResult SubgraphOp::verify() { Block *optBody = this->getOptionalEnclaveBody(); diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index ce152adf..b6ea7969 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -34,6 +34,12 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64> %d:2 = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> // CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64> + %e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args ["N" = 42 : i32, "idx" = 56 : index] {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args ["N" = 42 : i32, "idx" = 56 : index] {attr = 4 : i32} -> tensor<2xi64> + %f = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i32, 56 : index] {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i32, 56 : index] {attr = 4 : i32} -> tensor<2xi64> + %g = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42, 56, 1.0] {attr = 4 : i32} -> tensor<2xi64> + // CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i64, 56 : i64, 1.000000e+00 : f64] {attr = 4 : i32} -> tensor<2xi64> return } diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index ffd91190..6815f172 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -68,6 +68,31 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) { xten_nn.kernel "myKernel" () -> } +// ----- + +func.func @kernel_instantiation_list_different_length(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{instantiation arg names must be either empty or as long as instantiation args}} + %x = xten_nn.kernel "myKernel" () instantiation_args ["N" = 42 : i32, 51 : index] -> i32 + return +} + +// ----- + +func.func @kernel_instantiation_list_non_empty(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{cannot have instantiation arg names without instantiation args}} + %x = xten_nn.kernel "myKernel" () {instantiation_arg_names = ["N"]} -> i32 + return +} + +// ----- + +func.func @kernel_instantiation_list_non_quoted(%arg0: i8, %arg1: i8) { + // expected-error@+1 {{expected ']'}} + %x = xten_nn.kernel "myKernel" () instantiation_args [N = 42 : i32] -> i32 + return +} + + // ----- func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {