Skip to content

Commit

Permalink
Merge pull request #94 from Xilinx/corentin.kernel_template
Browse files Browse the repository at this point in the history
[FXML-5060] Add instantiation arguments to xten_nn.kernel op
  • Loading branch information
mgehre-amd authored Oct 4, 2024
2 parents b58d735 + bb61b2f commit a08963c
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 13 deletions.
17 changes: 16 additions & 1 deletion include/xten/Dialect/XTenNN/IR/XTenNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,31 @@ def XTenNN_KernelOp : XTenNN_Op<"kernel", []> {
```
%c = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64>
%d:2 = xten_nn.kernel "frob" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
%e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args {N = 42 : i32} {attr = 4 : i32} -> tensor<2xi64>
```
}];

let arguments = (ins
Variadic<AnyType>:$arguments,
StrAttr:$name
StrAttr:$name,
OptionalAttr<ArrayAttr>:$instantiation_args,
OptionalAttr<ArrayAttr>:$instantiation_arg_names
);
let results = (outs Variadic<AnyType>:$results);

let builders = [
OpBuilder<(ins
"::mlir::TypeRange":$results,
"::mlir::ValueRange":$arguments,
"::llvm::StringRef":$name), [{
build($_builder, $_state, results, arguments, name,
::mlir::ArrayAttr(), ::mlir::ArrayAttr());
}]
>
];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
137 changes: 125 additions & 12 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,31 @@
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/SmallVector.h"
#include "xten/Dialect/XTenNN/IR/XTenNNOps.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"

#include "xten/Dialect/XTenNN/IR/XTenNN.h"
#include "xten/Dialect/XTenNN/IR/XTenNNBase.h"
#include "xten/Dialect/XTenNN/IR/XTenNNOps.h"
#include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h"

#include <cstdint>

using namespace mlir;
Expand Down Expand Up @@ -215,15 +223,87 @@ static void printKernelArgumentList(OpAsmPrinter &p, TypeRange types,
}

// Parse
// $name custom<KernelArgumentList>(type($arguments), $arguments) attr-dict
// [((name = )?value, )*((name = )?value)]
static ParseResult parseKernelInstantiationArgs(OpAsmParser &p,
SmallVector<Attribute> &values,
SmallVector<Attribute> &names) {
if (failed(p.parseLSquare()))
return failure();

if (failed(p.parseCommaSeparatedList([&p, &names, &values]() {
std::string name;
if (succeeded(p.parseOptionalString(&name))) {
names.push_back(StringAttr::get(p.getContext(), name));
if (failed(p.parseEqual()))
return failure();
}
Attribute attr;
auto res = p.parseOptionalAttribute(attr);
if (res.has_value() && succeeded(*res)) {
values.push_back(attr);
}
if (res.has_value() && failed(*res))
return failure();

return success();
}))) {
return failure();
}

if (failed(p.parseRSquare()))
return failure();

return success();
}

// Print
// instantiation_args [((name = )?value, )*((name = )?value)]
static void
printKernelInstantiationArgs(OpAsmPrinter &p,
ArrayRef<Attribute> instantiationArgs,
ArrayRef<Attribute> instantiationArgNames) {
if (!instantiationArgs.empty()) {
p << "instantiation_args [";
auto zipped = llvm::zip_longest(instantiationArgNames, instantiationArgs);
for (auto iter = zipped.begin(); iter != zipped.end(); ++iter) {
if (iter != zipped.begin())
p << ", ";
auto [name, value] = *iter;
if (name)
p << *name << " = ";
if (value)
p.printAttribute(*value);
}
p << ']';
}
}

// Parse
// $name custom<KernelArgumentList>(type($arguments), $arguments)
// (instantiation_args custom<InstantiationArgs>)? attr-dict
// `->` type($results)
ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) {
StringAttr name;
if (p.parseAttribute(name, "name", result.attributes))
return failure();

if (parseKernelArgumentList(p, result.operands) ||
p.parseOptionalAttrDict(result.attributes))
if (parseKernelArgumentList(p, result.operands))
return failure();

if (succeeded(p.parseOptionalKeyword("instantiation_args"))) {
SmallVector<Attribute> values;
SmallVector<Attribute> names;
if (failed(parseKernelInstantiationArgs(p, values, names)))
return failure();
result.addAttribute("instantiation_args",
ArrayAttr::get(p.getContext(), values));
if (!names.empty()) {
result.addAttribute("instantiation_arg_names",
ArrayAttr::get(p.getContext(), names));
}
}

if (p.parseOptionalAttrDict(result.attributes))
return failure();

// If the op has no results, the `-> type($results)` is absent.
Expand All @@ -236,25 +316,60 @@ ParseResult KernelOp::parse(OpAsmParser &p, OperationState &result) {
return success();
}

// Parse
// $name custom<KernelArgumentList>(type($arguments), $arguments) attr-dict
// Print
// $name custom<KernelArgumentList>(type($arguments), $arguments)
// (instantiation_args custom<InstantiationArgs>)? attr-dict
// `->` type($results)
void KernelOp::print(OpAsmPrinter &p) {
p << ' ';
p << getNameAttr();
p << ' ';
printKernelArgumentList(p, getOperandTypes(), getOperands());
p << ' ';
SmallVector<StringRef> elidedAttrs = {"name"};
auto instantiationArgs = getInstantiationArgs();
auto instantiationArgNames = getInstantiationArgNames();
if (instantiationArgs != std::nullopt) {
printKernelInstantiationArgs(p, instantiationArgs->getValue(),
(instantiationArgNames == std::nullopt)
? ArrayRef<Attribute>()
: instantiationArgNames->getValue());
p << ' ';
}

SmallVector<StringRef> elidedAttrs = {"name", "instantiation_args",
"instantiation_arg_names"};
p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
if (getOperation()->getAttrs().size() > elidedAttrs.size())
if (llvm::any_of(
getOperation()->getAttrs(), [&elidedAttrs](NamedAttribute a) {
auto name = a.getName();
return llvm::any_of(elidedAttrs, [&name](StringRef elidedName) {
return name == elidedName;
});
}))
p << ' ';
if (getNumResults()) {
p << "-> ";
p << getResultTypes();
}
}

LogicalResult KernelOp::verify() {
if (getInstantiationArgNames().has_value()) {
if (!getInstantiationArgs().has_value()) {
return emitOpError(
"cannot have instantiation arg names without instantiation args");
}
if (!(getInstantiationArgNamesAttr().empty() ||
getInstantiationArgNamesAttr().size() ==
getInstantiationArgsAttr().size())) {
return emitOpError("instantiation arg names must be either empty or as "
"long as instantiation args");
}
}

return success();
}

#define GET_OP_CLASSES
#include "xten/Dialect/XTenNN/IR/XTenNNOps.cpp.inc"

Expand All @@ -266,9 +381,7 @@ ParseResult SubgraphOp::parse(OpAsmParser &p, OperationState &result) {
return parseEnclaveOp(p, result);
}

void SubgraphOp::print(OpAsmPrinter &p) {
printEnclaveOp(p, *this);
}
void SubgraphOp::print(OpAsmPrinter &p) { printEnclaveOp(p, *this); }

LogicalResult SubgraphOp::verify() {
Block *optBody = this->getOptionalEnclaveBody();
Expand Down
6 changes: 6 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>) {attr = 4 : i32} -> tensor<2xi64>
%d:2 = xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
%e = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args ["N" = 42 : i32, "idx" = 56 : index] {attr = 4 : i32} -> tensor<2xi64>
// CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args ["N" = 42 : i32, "idx" = 56 : index] {attr = 4 : i32} -> tensor<2xi64>
%f = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i32, 56 : index] {attr = 4 : i32} -> tensor<2xi64>
// CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i32, 56 : index] {attr = 4 : i32} -> tensor<2xi64>
%g = xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42, 56, 1.0] {attr = 4 : i32} -> tensor<2xi64>
// CHECK: xten_nn.kernel "matmul" (%arg0 : tensor<2xi64>) instantiation_args [42 : i64, 56 : i64, 1.000000e+00 : f64] {attr = 4 : i32} -> tensor<2xi64>
return
}

Expand Down
25 changes: 25 additions & 0 deletions test/Dialect/XTenNN/ops_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) {
xten_nn.kernel "myKernel" () ->
}

// -----

func.func @kernel_instantiation_list_different_length(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{instantiation arg names must be either empty or as long as instantiation args}}
%x = xten_nn.kernel "myKernel" () instantiation_args ["N" = 42 : i32, 51 : index] -> i32
return
}

// -----

func.func @kernel_instantiation_list_non_empty(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{cannot have instantiation arg names without instantiation args}}
%x = xten_nn.kernel "myKernel" () {instantiation_arg_names = ["N"]} -> i32
return
}

// -----

func.func @kernel_instantiation_list_non_quoted(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected ']'}}
%x = xten_nn.kernel "myKernel" () instantiation_args [N = 42 : i32] -> i32
return
}


// -----

func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
Expand Down

0 comments on commit a08963c

Please sign in to comment.