Add topk operator #66

Merged
merged 11 commits into from
Aug 21, 2024
26 changes: 26 additions & 0 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
@@ -512,6 +512,32 @@ def XtenNN_SignOp: XTenNN_Op<"sign", [Pure, TosaExtension, ElementwiseUnary, Sam
let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
}

def XtenNN_TopK: XTenNN_Op<"topk", [
InferTensorTypeAdaptor,
Pure, TosaExtension]> {
let summary = "Calculate the topk";
let description = [{
Follows the specification of ONNX TopK at opset 11
}];
let arguments = (ins
AnyTensor:$input,
I64:$k,
I64Attr:$axis,
I1Attr:$largest,
I1Attr:$sorted
);
let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices);

let assemblyFormat = [{ `(`$input `:` type($input) `,` $k `:` type($k)`)` attr-dict `->` type($output) `,` type($indices) }];

let extraClassDeclaration = [{
/// Returns whether two result types are compatible for this op; method used by
/// InferTypeOpInterface.
static bool isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r);
}];
}
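
With the assembly format above, a use of the new op reads roughly as follows (a sketch mirroring the tests added later in this PR; %in and %k are placeholder names):

  %k = arith.constant 7 : i64
  %values, %indices = xten_nn.topk(%in : tensor<10x8xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>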


def XtenNN_ConvTransposeOp: XTenNN_Op<"ConvTranspose",[Pure, TosaExtension]> {
let summary = "Perform ConvTranspose operation";
let description = [{
102 changes: 84 additions & 18 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
@@ -96,9 +97,7 @@ static ParseResult parseCaptures(OpAsmParser &p,
/// See parseCaptures() for more details.
static void printCaptures(OpAsmPrinter &p, ValueRange srcs) {
p << '(';
llvm::interleaveComma(srcs, p, [&](auto src) {
printCapture(p, src);
});
llvm::interleaveComma(srcs, p, [&](auto src) { printCapture(p, src); });
p << ')';
}

@@ -172,7 +171,6 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) {
};
}


//===----------------------------------------------------------------------===//
// KernelOp
//===----------------------------------------------------------------------===//
@@ -286,8 +284,7 @@ LogicalResult SubgraphOp::verify() {
}

// The type of the arguments must match the types of the block arguments
for (auto [idx, argType] :
enumerate(optBody->getArgumentTypes())) {
for (auto [idx, argType] : enumerate(optBody->getArgumentTypes())) {
if (this->getCapture(idx).getType() != argType) {
return this->emitOpError()
<< "type of operand #" << idx << " ("
@@ -349,11 +346,12 @@ OpFoldResult amd::xten_nn::QuantizeOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult amd::xten_nn::GroupQuantizeOp::fold(FoldAdaptor adaptor) {
// Fold away cases where a xten_nn.group_quantize is preceded by xten_nn.group_dequantize
// that uses the same shift factor and has same types.
// Fold away cases where a xten_nn.group_quantize is preceded by
// xten_nn.group_dequantize that uses the same shift factor and has same
// types.

auto dequantizeOp =
dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(getInput().getDefiningOp());
auto dequantizeOp = dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(
getInput().getDefiningOp());
if (!dequantizeOp)
return {};

@@ -412,19 +410,25 @@ LogicalResult amd::xten_nn::GroupQuantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (inputShape != quantsShape) {
return emitOpError() << "input and quants must have the same shape (" << inputShape << " v " << quantsShape << ")";
return emitOpError() << "input and quants must have the same shape ("
<< inputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
return emitOpError() << "scales and zeros must have the same shape ("
<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
return emitOpError() << "groups needs to be expressed in the innermost "
"dimension of scales vs quants ("
<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
return emitOpError() << "scales and quants must have the same shape except "
"for the innermost dimension ("
<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -441,19 +445,25 @@ LogicalResult amd::xten_nn::GroupDequantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (outputShape != quantsShape) {
return emitOpError() << "output and quants must have the same shape (" << outputShape << " v " << quantsShape << ")";
return emitOpError() << "output and quants must have the same shape ("
<< outputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
return emitOpError() << "scales and zeros must have the same shape ("
<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
return emitOpError() << "groups needs to be expressed in the innermost "
"dimension of scales vs quants ("
<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
return emitOpError() << "scales and quants must have the same shape except "
"for the innermost dimension ("
<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -519,3 +529,59 @@ LogicalResult amd::xten_nn::ResizeOp::verify() {

return success();
}

std::optional<uint64_t> getConstantK(Value k) {
auto *op = k.getDefiningOp();
if (!op) {
return {};
}
auto constantOp = dyn_cast<arith::ConstantOp>(op);
if (!constantOp)
return {};
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
if (!intAttr)
return {};
return (uint64_t)
intAttr.getInt(); // Always positive by definition of onnx.topk
}

LogicalResult TopK::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
TopK::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());

auto axis = adaptor.getAxis();
if (axis >= (uint64_t)inTy.getRank()) {
return emitOptionalError(location, "expected axis <= rank of input");
}

auto dimSize = inTy.getDimSize(axis);
auto k = getConstantK(adaptor.getK());
// If both k and dim are known statically, we can check that k <= dim
if (k && dimSize != ShapedType::kDynamic) {
if ((uint64_t)dimSize <= *k) {
return emitOptionalError(location, "expected k <= dimension size");
}
}

SmallVector<int64_t> resultShape{inTy.getShape()};
resultShape[axis] = k ? *k : ShapedType::kDynamic;

inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, inTy.getElementType()));
inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, IntegerType::get(context, 64)));
return success();
}

bool TopK::isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r) {
if (l.size() != r.size() || l.size() != 2)
return false;

auto sameElementType =
getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]) &&
getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]);
return sameElementType && succeeded(verifyCompatibleShapes(l, r));
}
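
Taken together: when $k is not produced by an arith.constant, getConstantK returns std::nullopt, the inferred size along the chosen axis is dynamic, and isCompatibleReturnTypes still accepts a fully static result annotation because only element types and shape compatibility are checked. A sketch of both situations, consistent with the ops.mlir tests below (%in and %k are placeholder names; %k here is a function argument, not a constant):

  %values, %indices = xten_nn.topk(%in : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
  %values2, %indices2 = xten_nn.topk(%in : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x3xf32>, tensor<10x3xi64>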
42 changes: 42 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
@@ -36,3 +36,45 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
return
}

// -----

// CHECK-LABEL: topk
func.func @topk(%arg0: tensor<10x8xf32>) {
%k = arith.constant 7 : i64
// CHECK: %[[C7:.*]] = arith.constant 7 : i64
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
return
}

// -----

// CHECK-LABEL: topk_arg
func.func @topk_arg(%arg0: tensor<10x8xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
return
}

// -----

// Make sure that the topk verification does not fail if the result type is
// static even though it cannot be statically inferred due to the dynamic k
// CHECK-LABEL: topk_arg_type_inference
func.func @topk_arg_type_inference(%arg0: tensor<10x8xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x3xf32>, tensor<10x3xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x3xf32>, tensor<10x3xi64>
return
}

// -----

// CHECK-LABEL: topk_arg_dyn_in
func.func @topk_arg_dyn_in(%arg0: tensor<?x?xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<?x?xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<?x?xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
return
}
40 changes: 40 additions & 0 deletions test/Dialect/XTenNN/ops_invalid.mlir
@@ -67,3 +67,43 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected non-function type}}
xten_nn.kernel "myKernel" () ->
}

// -----

func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<1xf32>', 'tensor<1xi64>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<7x10xf32>', 'tensor<7x10xf32>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
return
}

// -----

func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
%k = arith.constant 7 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis <= rank of input}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 3 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
%k = arith.constant 100 : i64
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected k <= dimension size}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>, %k : i64) {axis = 0 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}