Commit bb72770

Add the Indices to the outputs

Keep consistency with the onnx.topk op.

josel-amd committed Aug 20, 2024
1 parent a2b0fd6 commit bb72770
Showing 4 changed files with 51 additions and 36 deletions.
11 changes: 3 additions & 8 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
@@ -517,10 +517,7 @@ def XtenNN_TopK: XTenNN_Op<"topk", [
Pure, TosaExtension]> {
let summary = "Calculate the topk";
let description = [{
-Follows the specification of ONNX TopK at opset 11, except
-that
-- this operation does not return the `Indices` output.
-- the input `K` must be a compile-time constant
+Follows the specification of ONNX TopK at opset 11
}];
let arguments = (ins
AnyRankedTensor:$input,
@@ -529,11 +526,9 @@ def XtenNN_TopK: XTenNN_Op<"topk", [
I1Attr:$largest,
I1Attr:$sorted
);
-let results = (outs
-AnyRankedTensor:$output
-);
+let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices);

-let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) }];
+let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) `,` type($indices) }];
}


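With `$indices` declared as a second result, ODS generates an extra result accessor on the C++ op class. A minimal sketch of consuming both results; the prefixed accessor names `getOutput()`/`getIndices()` and the header path follow the usual ODS conventions and are assumptions, since the diff does not show them:

```cpp
#include "mlir/IR/Value.h"
#include "xten/Dialect/XTenNN/IR/XTenNNOps.h" // assumed header location

// Values keep the input's element type; indices hold the i64 positions of
// the selected elements along `axis` (see the shape inference change below).
void consumeTopK(amd::xten_nn::TopK topk) {
  mlir::Value values = topk.getOutput();
  mlir::Value indices = topk.getIndices();
  (void)values;
  (void)indices;
}
```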
47 changes: 29 additions & 18 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -96,9 +96,7 @@ static ParseResult parseCaptures(OpAsmParser &p,
/// See parseCaptures() for more details.
static void printCaptures(OpAsmPrinter &p, ValueRange srcs) {
p << '(';
-llvm::interleaveComma(srcs, p, [&](auto src) {
-printCapture(p, src);
-});
+llvm::interleaveComma(srcs, p, [&](auto src) { printCapture(p, src); });
p << ')';
}

@@ -172,7 +170,6 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) {
};
}
-

//===----------------------------------------------------------------------===//
// KernelOp
//===----------------------------------------------------------------------===//
@@ -286,8 +283,7 @@ LogicalResult SubgraphOp::verify() {
}

// The type of the arguments must match the types of the block arguments
-for (auto [idx, argType] :
-enumerate(optBody->getArgumentTypes())) {
+for (auto [idx, argType] : enumerate(optBody->getArgumentTypes())) {
if (this->getCapture(idx).getType() != argType) {
return this->emitOpError()
<< "type of operand #" << idx << " ("
@@ -349,11 +345,12 @@ OpFoldResult amd::xten_nn::QuantizeOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult amd::xten_nn::GroupQuantizeOp::fold(FoldAdaptor adaptor) {
-// Fold away cases where a xten_nn.group_quantize is preceeded by xten_nn.group_dequantize
-// that uses the same shift factor and has same types.
+// Fold away cases where a xten_nn.group_quantize is preceeded by
+// xten_nn.group_dequantize that uses the same shift factor and has same
+// types.

-auto dequantizeOp =
-dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(getInput().getDefiningOp());
+auto dequantizeOp = dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(
+getInput().getDefiningOp());
if (!dequantizeOp)
return {};

@@ -412,19 +409,25 @@ LogicalResult amd::xten_nn::GroupQuantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (inputShape != quantsShape) {
-return emitOpError() << "input and quants must have the same shape (" << inputShape << " v " << quantsShape << ")";
+return emitOpError() << "input and quants must have the same shape ("
+<< inputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
-return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
+return emitOpError() << "scales and zeros must have the same shape ("
+<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
-return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
+return emitOpError() << "groups needs to be expressed in the innermost "
+"dimension of scales vs quants ("
+<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
-return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
+return emitOpError() << "scales and quants must have the same shape except "
+"for the innermost dimension ("
+<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -441,19 +444,25 @@ LogicalResult amd::xten_nn::GroupDequantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (outputShape != quantsShape) {
-return emitOpError() << "output and quants must have the same shape (" << outputShape << " v " << quantsShape << ")";
+return emitOpError() << "output and quants must have the same shape ("
+<< outputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
-return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
+return emitOpError() << "scales and zeros must have the same shape ("
+<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
-return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
+return emitOpError() << "groups needs to be expressed in the innermost "
+"dimension of scales vs quants ("
+<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
-return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
+return emitOpError() << "scales and quants must have the same shape except "
+"for the innermost dimension ("
+<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -540,5 +549,7 @@ LogicalResult TopK::inferReturnTypeComponents(

inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, inTy.getElementType()));
+inferredReturnShapes.push_back(
+ShapedTypeComponents(resultShape, IntegerType::get(context, 64)));
return success();
}
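The inference now pushes a second ShapedTypeComponents with the same shape but an i64 element type, matching the int64 `Indices` output of ONNX TopK. A standalone sketch of the shape rule both results share (plain C++ with no MLIR dependencies; `inferTopKShape` is an illustrative name, not part of this patch):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Both TopK results share one shape: the input shape with the `axis`
// dimension replaced by `k`. Only the element types differ (values keep
// the input element type, indices are i64).
std::vector<int64_t> inferTopKShape(std::vector<int64_t> inputShape,
                                    int64_t axis, int64_t k) {
  assert(axis < static_cast<int64_t>(inputShape.size()) &&
         "expected axis <= rank of input");
  assert(k <= inputShape[axis] && "expected k <= dimension size");
  inputShape[axis] = k;
  return inputShape;
}

// inferTopKShape({10, 8}, /*axis=*/0, /*k=*/7) == {7, 8}, matching the
// `-> tensor<7x8xf32>, tensor<7x8xi64>` round-trip test below.
```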
10 changes: 5 additions & 5 deletions test/Dialect/XTenNN/ops.mlir
@@ -41,9 +41,9 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {

// CHECK-LABEL: topk
func.func @topk(%arg0: tensor<10x8xf32>) {
-xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>
-// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>
-xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>
-// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>
+xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
+// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
+xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
+// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
return
}
}
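Callers that build the op in C++ now have to supply both result types. A hedged sketch against the ODS-generated builder; the `axis` and `k` attribute declarations sit in the collapsed part of the .td diff, so their `I64Attr` typing, their declaration order, and the header path are assumptions based on the `: i64` attributes in the tests above:

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "xten/Dialect/XTenNN/IR/XTenNNOps.h" // assumed header location

// Mirrors the first round-trip test: a 10x8 input, k = 7 along axis 0.
// Parameter order is assumed from the declaration order in the .td file.
mlir::Operation *buildTopK(mlir::OpBuilder &b, mlir::Location loc,
                           mlir::Value input /* tensor<10x8xf32> */) {
  auto valuesTy = mlir::RankedTensorType::get({7, 8}, b.getF32Type());
  auto indicesTy = mlir::RankedTensorType::get({7, 8}, b.getI64Type());
  return b.create<amd::xten_nn::TopK>(
      loc, valuesTy, indicesTy, input,
      /*axis=*/b.getI64IntegerAttr(0), /*k=*/b.getI64IntegerAttr(7),
      /*largest=*/b.getBoolAttr(true), /*sorted=*/b.getBoolAttr(true));
}
```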
19 changes: 14 additions & 5 deletions test/Dialect/XTenNN/ops_invalid.mlir
@@ -72,8 +72,17 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) {

func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
-// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>' are incompatible with return type(s) of operation 'tensor<1xf32>'}}
-%a = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<1xf32>
+// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<1xf32>', 'tensor<1xi64>'}}
+%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
return
}

+// -----
+
+func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
+// expected-error@+2 {{failed to infer returned types}}
+// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<7x10xf32>', 'tensor<7x10xf32>'}}
+%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
+return
+}

@@ -82,7 +91,7 @@ func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis <= rank of input}}
-%a = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 3 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x10xf32>
+%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 3 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

@@ -91,6 +100,6 @@ func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected k <= dimension size}}
-%a = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 100 : i64, largest = true, sorted = true} -> tensor<10x10xf32>
+%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 100 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}
}
