Skip to content

Commit

Permalink
Add support for dynamic topk
Browse files Browse the repository at this point in the history
  • Loading branch information
josel-amd committed Aug 20, 2024
1 parent 5b11a93 commit 0f9a0c6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
7 changes: 2 additions & 5 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,23 +552,20 @@ LogicalResult TopK::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());

auto axis = adaptor.getAxis();
if (axis >= (uint64_t)inTy.getRank()) {
return emitOptionalError(location, "expected axis <= rank of input");
}

auto dimSize = inTy.getDimSize(axis);
auto k = getConstantK(adaptor.getK());
if (k) {
if (k && dimSize != ShapedType::kDynamic) {
if ((uint64_t)dimSize < *k) {
return emitOptionalError(location, "expected k <= dimension size");
}
}

if (dimSize < 0) {
return emitOptionalError(location, "expected positive dimSize");
}

SmallVector<int64_t> resultShape{inTy.getShape()};
resultShape[axis] = k ? *k : ShapedType::kDynamic;

Expand Down
9 changes: 9 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,12 @@ func.func @topk_arg(%arg0: tensor<10x8xf32>, %k: i64) {
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<10x?xf32>, tensor<10x?xi64>
return
}

// -----

// CHECK-LABEL: topk_arg_dyn_in
func.func @topk_arg_dyn_in(%arg0: tensor<?x?xf32>, %k: i64) {
xten_nn.topk(%arg0 : tensor<?x?xf32>, %k : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<?x?xf32>, %arg1 : i64) {axis = 1 : i64, largest = true, sorted = true} -> tensor<?x?xf32>, tensor<?x?xi64>
return
}

0 comments on commit 0f9a0c6

Please sign in to comment.