[ONNX][MLIR] Add support for onnx.gather op #2726

Merged 1 commit on Jan 19, 2024.
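
For context, onnx.Gather selects slices of `data` along `axis` at the positions given by `indices`; the result shape is data.shape[:axis] + indices.shape + data.shape[axis+1:]. The sketch below is a minimal row-major reference for these semantics, for illustration only (gatherRef is a hypothetical helper, not code from this PR).

// Illustrative reference for onnx.Gather semantics (assumed helper, not
// part of this PR). Tensors are flat row-major buffers plus shapes.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<float> gatherRef(const std::vector<float> &data,
                             const std::vector<int64_t> &dataShape,
                             const std::vector<int64_t> &indices,
                             const std::vector<int64_t> &indexShape,
                             int64_t axis) {
  // outer/inner: products of the data dims before/after `axis`.
  int64_t outer = 1, inner = 1, axisDim = dataShape[axis];
  for (int64_t i = 0; i < axis; ++i)
    outer *= dataShape[i];
  for (size_t i = axis + 1; i < dataShape.size(); ++i)
    inner *= dataShape[i];
  int64_t numIndices = 1;
  for (int64_t d : indexShape)
    numIndices *= d;

  // Result shape: dataShape[:axis] + indexShape + dataShape[axis+1:].
  std::vector<float> out(outer * numIndices * inner);
  for (int64_t o = 0; o < outer; ++o)
    for (int64_t n = 0; n < numIndices; ++n) {
      int64_t idx = indices[n];
      assert(idx >= 0 && idx < axisDim && "gather index out of range");
      for (int64_t i = 0; i < inner; ++i)
        out[(o * numIndices + n) * inner + i] =
            data[(o * axisDim + idx) * inner + i];
    }
  return out;
}

int main() {
  // data (3, 2) = [[0, 1], [10, 11], [20, 21]]; indices (2,) = [2, 0].
  std::vector<float> out =
      gatherRef({0, 1, 10, 11, 20, 21}, {3, 2}, {2, 0}, {2}, /*axis=*/0);
  assert((out == std::vector<float>{20, 21, 0, 1})); // rows 2 then 0
  return 0;
}
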
156 changes: 152 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -264,10 +264,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
for (uint64_t i = 1; i < operands.size(); i++) {
result = rewriter.create<Torch::AtenMaximumOp>(
binder.getLoc(), resultType, result, operands[i]);
}
rewriter.replaceOp(binder.op, result.getDefiningOp());
return success();
});
patterns.onOp("Min", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -323,6 +322,155 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp(
"Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value data, indices;
int64_t axis;
if (binder.tensorOperandAtIndex(data, 0) ||
binder.tensorOperandAtIndex(indices, 1) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(axis, "axis", 0))
return failure();
Location loc = binder.getLoc();

// 1. Get data shape and rank.
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
if (!dataTensorType || !dataTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expect non empty input data");
}
ArrayRef<int64_t> dataShape = dataTensorType.getSizes();
unsigned dataRank = dataShape.size();

// 2. Get indices shape and rank.
auto indexType = indices.getType().cast<Torch::ValueTensorType>();
if (!indexType || !indexType.hasSizes()) {
return rewriter.notifyMatchFailure(binder.op,
"Expect non empty index tensor");
}
ArrayRef<int64_t> indexShape = indexType.getSizes();
unsigned indexRank = indexShape.size();

// 3. Compute the total number of elements in the indices tensor, since
// we will flatten it to a single dimension. Also collect the dynamic
// index and data dimension sizes, which are used to build the result
// shapes below.
int64_t indexElemCount = 1;
for (int64_t dim : indexShape) {
if (dim == Torch::kUnknownSize) {
indexElemCount = Torch::kUnknownSize;
break;
}
indexElemCount *= dim;
}

Value constOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
SmallVector<Value> indexShapeTensor;
Value indexElemCountVal = constOne;
for (unsigned i = 0; i < indexRank; ++i) {
Value indexDimVal = rewriter.create<Torch::AtenSizeIntOp>(
loc, indices,
rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
indexShapeTensor.emplace_back(indexDimVal);
indexElemCountVal = rewriter.create<Torch::AtenMulIntOp>(
loc, indexElemCountVal, indexDimVal);
}

SmallVector<Value> dataShapeTensor;
for (unsigned i = 0; i < dataRank; ++i) {
dataShapeTensor.emplace_back(rewriter.create<Torch::AtenSizeIntOp>(
loc, data,
rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i))));
}

// 4. We cannot lower directly to torch.gather, because onnx.gather and
// torch.gather place the gathered input data at different locations of
// the output; see https://onnx.ai/onnx/operators/onnx__Gather.html for
// details. Instead, we flatten the indices tensor to one dimension and
// materialize it along the axis dimension of the data tensor. For
// example, with indices of shape (4, 5, 6), data of shape
// (8, 10, 11, 12), and axis=1: flatten the indices into a (120,)
// tensor and reshape it to (1, 120, 1, 1), then perform the
// torch.gather operation. Next, broadcast the gather output across the
// non-axis dimensions of the data tensor, giving a result of shape
// (8, 120, 11, 12). After broadcasting, expand the flattened indices
// dimension back to the original indices shape by reshaping
// (8, 120, 11, 12) to (8, 4, 5, 6, 11, 12), which is the expected
// final result.
SmallVector<int64_t> collapsedIndexShape(dataRank, 1);
collapsedIndexShape[axis] = indexElemCount;
Type collapsedIndexType = Torch::ValueTensorType::get(
indexType.getContext(), llvm::ArrayRef(collapsedIndexShape),
indexType.getOptionalDtype());

SmallVector<Value> collapsedIndexSize(dataRank, constOne);
collapsedIndexSize[axis] = indexElemCountVal;
auto collapsedIndexSizeList =
rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
collapsedIndexSize);

auto collapsedIndices = rewriter.create<Torch::AtenViewOp>(
loc, collapsedIndexType, indices, collapsedIndexSizeList);

// 5. Compute gather result type and perform gather operation.
Type gatherResultType = Torch::ValueTensorType::get(
dataTensorType.getContext(), llvm::ArrayRef(collapsedIndexShape),
dataTensorType.getOptionalDtype());
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(false));
auto gatherOp = rewriter.create<Torch::AtenGatherOp>(
loc, gatherResultType, data, constAxis, collapsedIndices,
/*sparseGrad=*/constFalse);

// 6. Broadcast the gather output to non-axis dimensions of data tensor.
SmallVector<int64_t> dataShapeVector(dataShape);
dataShapeVector[axis] = indexElemCount;
Type expandResultType = Torch::ValueTensorType::get(
dataTensorType.getContext(), llvm::ArrayRef(dataShapeVector),
dataTensorType.getOptionalDtype());

dataShapeTensor[axis] = indexElemCountVal;
auto expandSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
dataShapeTensor);
auto expandedGather = rewriter.create<Torch::AtenExpandOp>(
loc, expandResultType, gatherOp, expandSizeList,
/*implicit=*/constFalse);

// 7. Reshape the output produced at step 6, expanding the collapsed
// indices dimension back to the original indices shape. This yields
// the expected result of the onnx.gather op.
SmallVector<Value> resultShapeTensor;
for (unsigned i = 0; i < dataRank; ++i) {
if (i == axis) {
resultShapeTensor.insert(resultShapeTensor.end(),
indexShapeTensor.begin(),
indexShapeTensor.end());
continue;
}
resultShapeTensor.emplace_back(dataShapeTensor[i]);
}
auto resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
resultShapeTensor);

rewriter.replaceOpWithNewOp<Torch::AtenViewOp>(
binder.op, resultType, expandedGather, resultSizeList);
return success();
});
patterns.onOp(
"GatherElements", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
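As a sanity check on the shape bookkeeping above, here is a small standalone trace of the intermediate shapes for the example in the step-4 comment (indices (4, 5, 6), data (8, 10, 11, 12), axis=1); it is illustrative only, not part of the PR.

// Illustrative shape trace of the Gather lowering (not part of this PR).
#include <cstdint>
#include <iostream>
#include <vector>

static void print(const char *tag, const std::vector<int64_t> &s) {
  std::cout << tag << " (";
  for (size_t i = 0; i < s.size(); ++i)
    std::cout << s[i] << (i + 1 < s.size() ? ", " : "");
  std::cout << ")\n";
}

int main() {
  std::vector<int64_t> dataShape = {8, 10, 11, 12};
  std::vector<int64_t> indexShape = {4, 5, 6};
  int64_t axis = 1;

  // Steps 3-4: flatten indices and place the flattened dim at `axis`.
  int64_t numIndices = 1;
  for (int64_t d : indexShape)
    numIndices *= d; // 120
  std::vector<int64_t> collapsed(dataShape.size(), 1);
  collapsed[axis] = numIndices;
  print("collapsed indices:", collapsed); // (1, 120, 1, 1)

  // Steps 5-6: torch.gather yields the collapsed shape; expand it
  // across the non-axis dims of data.
  std::vector<int64_t> expanded = dataShape;
  expanded[axis] = numIndices;
  print("expanded gather:  ", expanded); // (8, 120, 11, 12)

  // Step 7: restore the original indices dims at the axis position.
  std::vector<int64_t> result;
  for (size_t i = 0; i < dataShape.size(); ++i) {
    if ((int64_t)i == axis)
      result.insert(result.end(), indexShape.begin(), indexShape.end());
    else
      result.push_back(dataShape[i]);
  }
  print("final result:     ", result); // (8, 4, 5, 6, 11, 12)
  return 0;
}
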
37 changes: 37 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -37,6 +37,43 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[

// -----

// CHECK-LABEL: func.func @test_gather
func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
// CHECK: %[[ARG1_SIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0]]
// CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1]], %[[ARG1_SIZE0]]
// CHECK: %[[INT1_0:.+]] = torch.constant.int 1
// CHECK: %[[ARG1_SIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_0]]
// CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[ARG1_SIZE1]]
// CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
// CHECK: %[[ARG1_SIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
// CHECK: %[[MUL3:.+]] = torch.aten.mul.int %[[MUL2]], %[[ARG1_SIZE2]] : !torch.int, !torch.int -> !torch.int
// CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
// CHECK: %[[ARG1_SIZE3:.+]] = torch.aten.size.int %arg1, %[[INT3]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
// CHECK: %[[MUL4:.+]] = torch.aten.mul.int %[[MUL3]], %[[ARG1_SIZE3]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[INT0_2:.+]] = torch.constant.int 0
// CHECK: %[[ARG0_SIZE0:.+]] = torch.aten.size.int %arg0, %[[INT0_2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
// CHECK: %[[INT1_3:.+]] = torch.constant.int 1
// CHECK: %[[ARG0_SIZE1:.+]] = torch.aten.size.int %arg0, %[[INT1_3]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
// CHECK: %[[INT2_4:.+]] = torch.constant.int 2
// CHECK: %[[ARG0_SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2_4]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
// CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VIEW1:.+]] = torch.aten.view %arg1, %[[LIST1]] : !torch.vtensor<[8,10,20,40],si64>, !torch.list<int> -> !torch.vtensor<[64000,1,1],si64>
// CHECK: %[[INT0_1:.+]] = torch.constant.int 0
// CHECK: %[[FALSE:.+]] = torch.constant.bool false
// CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0_1]], %[[VIEW1]], %[[FALSE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.vtensor<[64000,1,1],si64>, !torch.bool -> !torch.vtensor<[64000,1,1],f32>
// CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[GATHER]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[64000,1,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[64000,4,5],f32>
// CHECK: %[[LIST3:.+]] = torch.prim.ListConstruct %[[ARG1_SIZE0]], %[[ARG1_SIZE1]], %[[ARG1_SIZE2]], %[[ARG1_SIZE3]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RES:.+]] = torch.aten.view %[[EXPAND]], %[[LIST3]] : !torch.vtensor<[64000,4,5],f32>, !torch.list<int> -> !torch.vtensor<[8,10,20,40,4,5],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[8,10,20,40,4,5],f32>
%0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32>
return %0 : !torch.vtensor<[8,10,20,40,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_gather_elements
func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
// CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
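
Assuming this test file keeps torch-mlir's usual RUN line, the cases above are exercised by running torch-mlir-opt with the -convert-torch-onnx-to-torch pass over the file (with -split-input-file for the // ----- separators) and piping the result through FileCheck.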