From edc95ff281b4f69ca8634e0c116b585baf714b76 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Tue, 20 Aug 2024 11:36:45 +0100 Subject: [PATCH] Add the Indice to the outputs Keep consistency with the onnx.topk --- include/xten/Dialect/XTenNN/IR/XTenNNOps.td | 11 +-- lib/Dialect/XTenNN/IR/XTenNNOps.cpp | 47 ++++++++----- test/Dialect/XTenNN/ops.mlir | 10 +-- test/Dialect/XTenNN/ops_invalid.mlir | 76 +++------------------ 4 files changed, 46 insertions(+), 98 deletions(-) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index a42ee202..02a12a92 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -517,10 +517,7 @@ def XtenNN_TopK: XTenNN_Op<"topk", [ Pure, TosaExtension]> { let summary = "Calculate the topk"; let description = [{ - Follows the specification of ONNX TopK at opset 11, except - that - - this operation does not return the `Indices` output. - - the input `K` must be a compile-time constant + Follows the specification of ONNX TopK at opset 11 }]; let arguments = (ins AnyRankedTensor:$input, @@ -529,11 +526,9 @@ def XtenNN_TopK: XTenNN_Op<"topk", [ I1Attr:$largest, I1Attr:$sorted ); - let results = (outs - AnyRankedTensor:$output - ); + let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices); - let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) }]; + let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) `,` type($indices) }]; } diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index f6946fae..53e8d8c7 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -96,9 +96,7 @@ static ParseResult parseCaptures(OpAsmParser &p, /// See parseCaptures() for more details. static void printCaptures(OpAsmPrinter &p, ValueRange srcs) { p << '('; - llvm::interleaveComma(srcs, p, [&](auto src) { - printCapture(p, src); - }); + llvm::interleaveComma(srcs, p, [&](auto src) { printCapture(p, src); }); p << ')'; } @@ -172,7 +170,6 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) { }; } - //===----------------------------------------------------------------------===// // KernelOp //===----------------------------------------------------------------------===// @@ -286,8 +283,7 @@ LogicalResult SubgraphOp::verify() { } // The type of the arguments must match the types of the block arguments - for (auto [idx, argType] : - enumerate(optBody->getArgumentTypes())) { + for (auto [idx, argType] : enumerate(optBody->getArgumentTypes())) { if (this->getCapture(idx).getType() != argType) { return this->emitOpError() << "type of operand #" << idx << " (" @@ -349,11 +345,12 @@ OpFoldResult amd::xten_nn::QuantizeOp::fold(FoldAdaptor adaptor) { } OpFoldResult amd::xten_nn::GroupQuantizeOp::fold(FoldAdaptor adaptor) { - // Fold away cases where a xten_nn.group_quantize is preceeded by xten_nn.group_dequantize - // that uses the same shift factor and has same types. + // Fold away cases where a xten_nn.group_quantize is preceeded by + // xten_nn.group_dequantize that uses the same shift factor and has same + // types. - auto dequantizeOp = - dyn_cast_or_null(getInput().getDefiningOp()); + auto dequantizeOp = dyn_cast_or_null( + getInput().getDefiningOp()); if (!dequantizeOp) return {}; @@ -412,19 +409,25 @@ LogicalResult amd::xten_nn::GroupQuantizeOp::verify() { auto quantsShape = cast(getQuants().getType()).getShape(); if (inputShape != quantsShape) { - return emitOpError() << "input and quants must have the same shape (" << inputShape << " v " << quantsShape << ")"; + return emitOpError() << "input and quants must have the same shape (" + << inputShape << " v " << quantsShape << ")"; } if (scalesShape != zerosShape) { - return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")"; + return emitOpError() << "scales and zeros must have the same shape (" + << scalesShape << " v " << zerosShape << ")"; } if (scalesShape.back() != 1) { - return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ; + return emitOpError() << "groups needs to be expressed in the innermost " + "dimension of scales vs quants (" + << scalesShape.back() << ")"; } if (scalesShape.drop_back() != quantsShape.drop_back()) { - return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")"; + return emitOpError() << "scales and quants must have the same shape except " + "for the innermost dimension (" + << scalesShape << " v " << quantsShape << ")"; } // TODO validate: @@ -441,19 +444,25 @@ LogicalResult amd::xten_nn::GroupDequantizeOp::verify() { auto quantsShape = cast(getQuants().getType()).getShape(); if (outputShape != quantsShape) { - return emitOpError() << "output and quants must have the same shape (" << outputShape << " v " << quantsShape << ")"; + return emitOpError() << "output and quants must have the same shape (" + << outputShape << " v " << quantsShape << ")"; } if (scalesShape != zerosShape) { - return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")"; + return emitOpError() << "scales and zeros must have the same shape (" + << scalesShape << " v " << zerosShape << ")"; } if (scalesShape.back() != 1) { - return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ; + return emitOpError() << "groups needs to be expressed in the innermost " + "dimension of scales vs quants (" + << scalesShape.back() << ")"; } if (scalesShape.drop_back() != quantsShape.drop_back()) { - return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")"; + return emitOpError() << "scales and quants must have the same shape except " + "for the innermost dimension (" + << scalesShape << " v " << quantsShape << ")"; } // TODO validate: @@ -540,5 +549,7 @@ LogicalResult TopK::inferReturnTypeComponents( inferredReturnShapes.push_back( ShapedTypeComponents(resultShape, inTy.getElementType())); + inferredReturnShapes.push_back( + ShapedTypeComponents(resultShape, IntegerType::get(context, 64))); return success(); } diff --git a/test/Dialect/XTenNN/ops.mlir b/test/Dialect/XTenNN/ops.mlir index 7f129583..62ccf09c 100644 --- a/test/Dialect/XTenNN/ops.mlir +++ b/test/Dialect/XTenNN/ops.mlir @@ -41,9 +41,9 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) { // CHECK-LABEL: topk func.func @topk(%arg0: tensor<10x8xf32>) { - xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32> - // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32> - xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32> - // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32> + xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64> + // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64> + xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64> + // CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64> return -} +} \ No newline at end of file diff --git a/test/Dialect/XTenNN/ops_invalid.mlir b/test/Dialect/XTenNN/ops_invalid.mlir index 5d83f8b7..f8d47764 100644 --- a/test/Dialect/XTenNN/ops_invalid.mlir +++ b/test/Dialect/XTenNN/ops_invalid.mlir @@ -3,77 +3,19 @@ // RUN: aten-opt %s -split-input-file -verify-diagnostics -func.func @atan2_int(%arg0: tensor<1x10xi4>, %arg1: tensor<1x10xi4>) -> tensor<1x10xi4> { - // expected-error@+1 {{op operand #0 must be tensor of floating-point values, but got 'tensor<1x10xi4>'}} - %0 = xten_nn.atan2 %arg0, %arg1 : (tensor<1x10xi4>, tensor<1x10xi4>) -> tensor<1x10xi4> - return %0 : tensor<1x10xi4> -} - -// ----- - -func.func @elu_int(%arg0: tensor<1x10xi4>) -> tensor<1x10xi4> { - // expected-error@+1 {{op operand #0 must be tensor of floating-point values, but got 'tensor<1x10xi4>'}} - %0 = xten_nn.elu %arg0 { alpha = 1.000000e-00 : f32} : (tensor<1x10xi4>) -> tensor<1x10xi4> - return %0 : tensor<1x10xi4> -} - -// ----- - -func.func @mish_int(%arg0: tensor<1x10xi4>) -> tensor<1x10xi4> { - // expected-error@+1 {{op operand #0 must be tensor of floating-point values, but got 'tensor<1x10xi4>'}} - %0 = xten_nn.mish %arg0 : (tensor<1x10xi4>) -> tensor<1x10xi4> - return %0 : tensor<1x10xi4> -} - -// ----- - -func.func @kernel_missing_parenthesis() { - // expected-error@+1 {{expected '('}} - %a = xten_nn.kernel "myKernel" -> tensor<2xi64> -} - -// ----- - -func.func @kernel_missing_colon(%arg0: i8, %arg1: i8) { - // expected-error@+1 {{expected ':`, (argument format is val : type)}} - %a = xten_nn.kernel "myKernel" (%arg0, %arg1) -> tensor<2xi64> -} - -// ----- - -func.func @kernel_missing_type(%arg0: i8, %arg1: i8) { - // expected-error@+1 {{expected non-function type}} - %a = xten_nn.kernel "myKernel" (%arg0 : ) -> tensor<2xi64> -} - -// ----- - -func.func @kernel_trailing_comma(%arg0: i8) { - // expected-error@+1 {{expected SSA operand}} - %a = xten_nn.kernel "myKernel" (%arg0 :i8, ) -> tensor<2xi64> -} - -// ----- - -func.func @kernel_missing_name() { - // expected-error@+1 {{'xten_nn.kernel' invalid kind of attribute specified}} - %b = xten_nn.kernel () -> tensor<2xi64> +func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) { + // expected-error@+2 {{failed to infer returned types}} + // expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<1xf32>', 'tensor<1xi64>'}} + %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64> return } // ----- -func.func @kernel_missing_result(%arg0: i8, %arg1: i8) { - // expected-error@+1 {{expected non-function type}} - xten_nn.kernel "myKernel" () -> -} - -// ----- - -func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) { +func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) { // expected-error@+2 {{failed to infer returned types}} - // expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>' are incompatible with return type(s) of operation 'tensor<1xf32>'}} - %a = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<1xf32> + // expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<7x10xf32>', 'tensor<7x10xf32>'}} + %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32> return } @@ -82,7 +24,7 @@ func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) { func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{expected axis <= rank of input}} - %a = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 3 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x10xf32> + %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 3 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64> return } @@ -91,6 +33,6 @@ func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) { func.func @topk_large_k(%arg0: tensor<10x10xf32>) { // expected-error@+2 {{failed to infer returned types}} // expected-error@+1 {{expected k <= dimension size}} - %a = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 100 : i64, largest = true, sorted = true} -> tensor<10x10xf32> + %a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 100 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64> return }