Skip to content

Commit

Permalink
Relax result type comparison
Browse files Browse the repository at this point in the history
Check if the types are compatible insteal of equal. Sometimes the result
type can be made static and type inference cannot deduce it e.g. when
the k is dynamic
  • Loading branch information
josel-amd committed Aug 21, 2024
1 parent 9d060f9 commit 9939251
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/xten/Dialect/XTenNN/IR/XTenNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,12 @@ def XtenNN_TopK: XTenNN_Op<"topk", [
let results = (outs AnyRankedTensor:$output, AnyRankedTensor:$indices);

let assemblyFormat = [{ `(`$input `:` type($input) `,` $k `:` type($k)`)` attr-dict `->` type($output) `,` type($indices) }];

let extraClassDeclaration = [{
/// Returns when two result types are compatible for this op; method used by
/// InferTypeOpInterface.
static bool isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r);
}];
}


Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/XTenNN/IR/XTenNNOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,9 @@ LogicalResult TopK::inferReturnTypeComponents(
ShapedTypeComponents(resultShape, IntegerType::get(context, 64)));
return success();
}

bool TopK::isCompatibleReturnTypes(mlir::TypeRange l, mlir::TypeRange r) {
if (l.size() != r.size() || l.size() != 2)

This comment has been minimized.

Copy link
@mgehre-amd

mgehre-amd Aug 21, 2024

Collaborator

Why is this not if (l.size() != r.size())?

return false;
return succeeded(verifyCompatibleShapes(l, r));
}

0 comments on commit 9939251

Please sign in to comment.