Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add topk operator #66

Merged
merged 11 commits into from
Aug 21, 2024
26 changes: 14 additions & 12 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
Expand Down Expand Up @@ -530,7 +531,11 @@ LogicalResult amd::xten_nn::ResizeOp::verify() {
return success();
}

std::optional<uint64_t> getConstantK(Operation *op) {
std::optional<uint64_t> getConstantK(Value k) {
auto *op = k.getDefiningOp();
if (!op) {
return {};
}
auto constantOp = dyn_cast<arith::ConstantOp>(op);
if (!constantOp)
return {};
Expand All @@ -551,24 +556,21 @@ LogicalResult TopK::inferReturnTypeComponents(
if (axis >= (uint64_t)inTy.getRank()) {
return emitOptionalError(location, "expected axis <= rank of input");
}
auto dimSize = inTy.getDimSize(axis);
auto k = getConstantK(adaptor.getK().getDefiningOp());

if (!k) {
return emitOptionalError(location, "expected constant k");
auto dimSize = inTy.getDimSize(axis);
auto k = getConstantK(adaptor.getK());
if (k) {
if ((uint64_t)dimSize < *k) {
return emitOptionalError(location, "expected k <= dimension size");
}
}

if (dimSize < 0) {
// TODO: Support negative dimSize
return emitOptionalError(location, "expected positive axis");
}

if ((uint64_t)dimSize < *k) {
return emitOptionalError(location, "expected k <= dimension size");
return emitOptionalError(location, "expected positive dimSize");
}

SmallVector<int64_t> resultShape{inTy.getShape()};
resultShape[axis] = *k;
resultShape[axis] = k ? *k : ShapedType::kDynamic;

inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, inTy.getElementType()));
Expand Down
9 changes: 9 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ func.func @topk(%arg0: tensor<10x8xf32>) {
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %[[C7]] : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
return
}

// -----

// Test case: xten_nn.topk where `k` is a function argument (%k) instead of a
// constant defined in the function. The result types use a dynamic size (`?`)
// along axis 1, and the CHECK line expects the op to be printed back with
// those same dynamic result types and with %k shown as %arg1.
// CHECK-LABEL: topk_arg
func.func @topk_arg(%arg0: tensor<10x8xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<10x8xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
return
}
Loading