Add topk operator #66

Merged (11 commits) on Aug 21, 2024. The diff below shows changes from 2 commits.
20 changes: 20 additions & 0 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
@@ -512,6 +512,26 @@ def XtenNN_SignOp: XTenNN_Op<"sign", [Pure, TosaExtension, ElementwiseUnary, Sam
let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }];
}

def XtenNN_TopK: XTenNN_Op<"topk", [
InferTensorTypeAdaptor,
Pure, TosaExtension]> {
let summary = "Calculate the top-k values and indices along an axis";
let description = [{
Follows the specification of ONNX TopK at opset 11
}];
let arguments = (ins
AnyRankedTensor:$input,
I64Attr:$k,
I64Attr:$axis,
I1Attr:$largest,
I1Attr:$sorted
);
let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices);

let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) `,` type($indices) }];
}


def XtenNN_ConvTransposeOp: XTenNN_Op<"ConvTranspose",[Pure, TosaExtension]> {
let summary = "Perform ConvTranspose operation";
let description = [{
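For reference, the xten_nn.topk operation defined above is written with its custom assembly format as in the following sketch (illustrative only, mirroring the positive tests added to ops.mlir in this PR; the result names are arbitrary):

  %values, %indices = xten_nn.topk(%input : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>

The first result holds the selected values and the second the corresponding indices as i64, following ONNX TopK at opset 11.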
70 changes: 52 additions & 18 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
@@ -96,9 +96,7 @@ static ParseResult parseCaptures(OpAsmParser &p,
/// See parseCaptures() for more details.
static void printCaptures(OpAsmPrinter &p, ValueRange srcs) {
p << '(';
llvm::interleaveComma(srcs, p, [&](auto src) {
printCapture(p, src);
});
llvm::interleaveComma(srcs, p, [&](auto src) { printCapture(p, src); });
p << ')';
}

@@ -172,7 +170,6 @@ static void printEnclaveOp(OpAsmPrinter &p, EnclaveOp op) {
};
}


//===----------------------------------------------------------------------===//
// KernelOp
//===----------------------------------------------------------------------===//
@@ -286,8 +283,7 @@ LogicalResult SubgraphOp::verify() {
}

// The type of the arguments must match the types of the block arguments
for (auto [idx, argType] :
enumerate(optBody->getArgumentTypes())) {
for (auto [idx, argType] : enumerate(optBody->getArgumentTypes())) {
if (this->getCapture(idx).getType() != argType) {
return this->emitOpError()
<< "type of operand #" << idx << " ("
@@ -349,11 +345,12 @@ OpFoldResult amd::xten_nn::QuantizeOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult amd::xten_nn::GroupQuantizeOp::fold(FoldAdaptor adaptor) {
// Fold away cases where a xten_nn.group_quantize is preceeded by xten_nn.group_dequantize
// that uses the same shift factor and has same types.
// Fold away cases where a xten_nn.group_quantize is preceded by a
// xten_nn.group_dequantize that uses the same shift factor and has the same
// types.

auto dequantizeOp =
dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(getInput().getDefiningOp());
auto dequantizeOp = dyn_cast_or_null<amd::xten_nn::GroupDequantizeOp>(
getInput().getDefiningOp());
if (!dequantizeOp)
return {};

@@ -412,19 +409,25 @@ LogicalResult amd::xten_nn::GroupQuantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (inputShape != quantsShape) {
return emitOpError() << "input and quants must have the same shape (" << inputShape << " v " << quantsShape << ")";
return emitOpError() << "input and quants must have the same shape ("
<< inputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
return emitOpError() << "scales and zeros must have the same shape ("
<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
return emitOpError() << "groups needs to be expressed in the innermost "
"dimension of scales vs quants ("
<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
return emitOpError() << "scales and quants must have the same shape except "
"for the innermost dimension ("
<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -441,19 +444,25 @@ LogicalResult amd::xten_nn::GroupDequantizeOp::verify() {
auto quantsShape = cast<ShapedType>(getQuants().getType()).getShape();

if (outputShape != quantsShape) {
return emitOpError() << "output and quants must have the same shape (" << outputShape << " v " << quantsShape << ")";
return emitOpError() << "output and quants must have the same shape ("
<< outputShape << " v " << quantsShape << ")";
}

if (scalesShape != zerosShape) {
return emitOpError() << "scales and zeros must have the same shape (" << scalesShape << " v " << zerosShape << ")";
return emitOpError() << "scales and zeros must have the same shape ("
<< scalesShape << " v " << zerosShape << ")";
}

if (scalesShape.back() != 1) {
return emitOpError() << "groups needs to be expressed in the innermost dimension of scales vs quants (" << scalesShape.back() << ")" ;
return emitOpError() << "groups needs to be expressed in the innermost "
"dimension of scales vs quants ("
<< scalesShape.back() << ")";
}

if (scalesShape.drop_back() != quantsShape.drop_back()) {
return emitOpError() << "scales and quants must have the same shape except for the innermost dimension (" << scalesShape << " v " << quantsShape << ")";
return emitOpError() << "scales and quants must have the same shape except "
"for the innermost dimension ("
<< scalesShape << " v " << quantsShape << ")";
}

// TODO validate:
@@ -519,3 +528,28 @@ LogicalResult amd::xten_nn::ResizeOp::verify() {

return success();
}

LogicalResult TopK::inferReturnTypeComponents(
MLIRContext *context, std::optional<Location> location,
TopK::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

auto inTy = cast<RankedTensorType>(adaptor.getInput().getType());
auto axis = adaptor.getAxis();
if (axis >= (uint64_t)inTy.getRank()) {
return emitOptionalError(location, "expected axis <= rank of input");
}
auto dimSize = inTy.getDimSize(axis);
if ((uint64_t)dimSize < adaptor.getK()) {
return emitOptionalError(location, "expected k <= dimension size");
}

SmallVector<int64_t> resultShape{inTy.getShape()};
resultShape[axis] = adaptor.getK();

inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, inTy.getElementType()));
inferredReturnShapes.push_back(
ShapedTypeComponents(resultShape, IntegerType::get(context, 64)));
return success();
}
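As a concrete illustration of the inference above (a sketch using the same kind of operands as the tests below): for an input of tensor<10x8xf32> with axis = 1 and k = 3, resultShape starts as [10, 8], the axis dimension is replaced by k to give [10, 3], and the inferred components are tensor<10x3xf32> for the values and tensor<10x3xi64> for the indices:

  %values, %indices = xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 3 : i64, largest = true, sorted = true} -> tensor<10x3xf32>, tensor<10x3xi64>

Because the op carries the InferTensorTypeAdaptor trait, declaring result types that disagree with these inferred components is rejected with the "inferred type(s) ... are incompatible with return type(s)" diagnostic exercised in ops_invalid.mlir.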
11 changes: 11 additions & 0 deletions test/Dialect/XTenNN/ops.mlir
@@ -36,3 +36,14 @@ func.func @kernel(%arg0: tensor<2xi64>, %arg1 : tensor<4xi64>) {
// CHECK: xten_nn.kernel "myKernel" (%arg0 : tensor<2xi64>, %arg1 : tensor<4xi64>) -> tensor<2xi64>, tensor<1xi64>
return
}

// -----

// CHECK-LABEL: topk
func.func @topk(%arg0: tensor<10x8xf32>) {
xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x8xf32>, tensor<7x8xi64>
xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
// CHECK: xten_nn.topk(%arg0 : tensor<10x8xf32>) {axis = 1 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x7xf32>, tensor<10x7xi64>
return
}
36 changes: 36 additions & 0 deletions test/Dialect/XTenNN/ops_invalid.mlir
@@ -67,3 +67,39 @@ func.func @kernel_missing_result(%arg0: i8, %arg1: i8) {
// expected-error@+1 {{expected non-function type}}
xten_nn.kernel "myKernel" () ->
}

// -----

func.func @topk_wrong_output_shape(%arg0: tensor<10x10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<1xf32>', 'tensor<1xi64>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<1xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_wrong_indices_shape(%arg0: tensor<10x10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{'xten_nn.topk' op inferred type(s) 'tensor<7x10xf32>', 'tensor<7x10xi64>' are incompatible with return type(s) of operation 'tensor<7x10xf32>', 'tensor<7x10xf32>'}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<7x10xf32>, tensor<7x10xf32>
return
}

// -----

func.func @topk_wrong_axis(%arg0: tensor<10x10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected axis <= rank of input}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 3 : i64, k = 7 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}

// -----

func.func @topk_large_k(%arg0: tensor<10x10xf32>) {
// expected-error@+2 {{failed to infer returned types}}
// expected-error@+1 {{expected k <= dimension size}}
%a, %b = xten_nn.topk(%arg0 : tensor<10x10xf32>) {axis = 0 : i64, k = 100 : i64, largest = true, sorted = true} -> tensor<10x10xf32>, tensor<1xi64>
return
}