From 423e76d3d47468ff8cf153c28d9f0d1a14794c8c Mon Sep 17 00:00:00 2001
From: Gaurav Shukla
Date: Wed, 3 Jan 2024 23:24:06 +0530
Subject: [PATCH] [ONNX][MLIR] Add support for onnx.gather op

This commit adds support for the onnx.Gather op in the ONNX-to-Torch
conversion pipeline. The lowering collapses the indices into a single
gather dimension via aten.view, applies aten.gather along `axis`,
expands the result over the remaining data dimensions with aten.expand,
and restores the expected result shape with a final aten.view.

Signed-off-by: Gaurav Shukla
---
 .../Conversion/TorchOnnxToTorch/Patterns.h    |   2 +-
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    | 134 +++++++++++++++++-
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   |  39 +++++
 3 files changed, 169 insertions(+), 6 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
index d842ea77bd3cf..072eb94571d5f 100644
--- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
+++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
@@ -78,7 +78,7 @@ struct OpBinder {
   }
   ParseResult tensorOperandsList(
       llvm::SmallVectorImpl<Value> &values) {
-    for (int i = 0; i < op->getNumOperands(); i++) {
+    for (unsigned i = 0; i < op->getNumOperands(); i++) {
      values.push_back(op->getOperand(i));
    }
    return success();
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index d154edb1ab750..140b8a7294a3d 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -182,11 +182,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
          return failure();
        }
        Value result = operands[0];
-        for (int i = 1; i < operands.size(); i++) {
-          result = rewriter.create<Torch::AtenMaximumOp>(
+        for (unsigned i = 1; i < operands.size(); i++) {
+          result = rewriter.create<Torch::AtenMaximumOp>(
            binder.getLoc(), resultType, result, operands[i]);
-        }
-        rewriter.replaceOp(
+        }
+        rewriter.replaceOp(
          binder.op, result.getDefiningOp());
        return success();
      });
@@ -200,7 +200,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
          return failure();
        }
        Value result = operands[0];
-        for (int i = 1; i < operands.size(); i++) {
+        for (unsigned i = 1; i < operands.size(); i++) {
          result = rewriter.create<Torch::AtenMinimumOp>(
            binder.getLoc(), resultType, result, operands[i]);
        }
@@ -244,6 +244,130 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
            binder.op, resultType, lhs, rhs);
        return success();
      });
+  patterns.onOp(
+      "Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value data, indices;
+        int64_t axis;
+        if (binder.tensorOperandAtIndex(data, 0) ||
+            binder.tensorOperandAtIndex(indices, 1) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerAttr(axis, "axis", 0))
+          return failure();
+        Location loc = binder.getLoc();
+        // Get data shape and rank.
+        auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
+        if (!dataTensorType || !dataTensorType.hasSizes()) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Expect non-empty input data");
+        }
+        ArrayRef<int64_t> dataShape = dataTensorType.getSizes();
+        unsigned dataRank = dataShape.size();
+
+        // Compute the total number of elements in the indices tensor.
+        auto indexType = indices.getType().cast<Torch::ValueTensorType>();
+        if (!indexType || !indexType.hasSizes()) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "Expect non-empty index tensor");
+        }
+        ArrayRef<int64_t> indexShape = indexType.getSizes();
+        unsigned indexRank = indexShape.size();
+        int64_t indexElemCount = 1;
+        for (int64_t dim : indexShape) {
+          if (dim == Torch::kUnknownSize) {
+            indexElemCount = Torch::kUnknownSize;
+            break;
+          }
+          indexElemCount *= dim;
+        }
+
+        // Collapse the indices into a rank-`dataRank` tensor that carries all
+        // `indexElemCount` index elements along `axis` and has size 1 along
+        // every other (non-axis) dimension of the data shape.
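+        // For example, with the shapes exercised by the lit test below
+        // (data [3,4,5], indices [8,10,20,40], axis = 0), this produces
+        // indexElemCount = 8*10*20*40 = 64000 and a collapsed index shape
+        // of [64000,1,1].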
+        SmallVector<int64_t> collapsedIndexShape(dataRank, 1);
+        collapsedIndexShape[axis] = Torch::kUnknownSize;
+        if (indexElemCount != Torch::kUnknownSize)
+          collapsedIndexShape[axis] = indexElemCount;
+
+        Value constOne = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(1));
+        SmallVector<Value> indexShapeTensor;
+        Value prod = constOne;
+        for (unsigned i = 0; i < indexRank; ++i) {
+          Value indexDimVal = rewriter.create<Torch::AtenSizeIntOp>(
+              loc, indices,
+              rewriter.create<Torch::ConstantIntOp>(
+                  loc, rewriter.getI64IntegerAttr(i)));
+          indexShapeTensor.emplace_back(indexDimVal);
+          prod = rewriter.create<Torch::AtenMulIntOp>(loc, prod, indexDimVal);
+        }
+
+        SmallVector<Value> collapsedIndexSize(dataRank, constOne);
+        collapsedIndexSize[axis] = prod;
+        auto collapsedIndexSizeList =
+            rewriter.create<Torch::PrimListConstructOp>(
+                loc,
+                Torch::ListType::get(Torch::IntType::get(indices.getContext())),
+                collapsedIndexSize);
+
+        Type collapsedIndexType = Torch::ValueTensorType::get(
+            indexType.getContext(), llvm::ArrayRef(collapsedIndexShape),
+            indexType.getOptionalDtype());
+        auto collapsedIndices = rewriter.create<Torch::AtenViewOp>(
+            loc, collapsedIndexType, indices, collapsedIndexSizeList);
+
+        Type gatherResultType = Torch::ValueTensorType::get(
+            dataTensorType.getContext(), llvm::ArrayRef(collapsedIndexShape),
+            dataTensorType.getOptionalDtype());
+        Value constAxis = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis));
+        Value constFalse = rewriter.create<Torch::ConstantBoolOp>(
+            binder.getLoc(), rewriter.getType<Torch::BoolType>(),
+            rewriter.getBoolAttr(false));
+        auto gatherOp = rewriter.create<Torch::AtenGatherOp>(
+            loc, gatherResultType, data, constAxis, collapsedIndices,
+            /*sparseGrad=*/constFalse);
+
+        SmallVector<int64_t> dataShapeVector(dataShape);
+        dataShapeVector[axis] = Torch::kUnknownSize;
+        if (indexElemCount != Torch::kUnknownSize)
+          dataShapeVector[axis] = indexElemCount;
+        Type expandResultType = Torch::ValueTensorType::get(
+            dataTensorType.getContext(), llvm::ArrayRef(dataShapeVector),
+            dataTensorType.getOptionalDtype());
+        SmallVector<Value> dataShapeTensor;
+        for (unsigned i = 0; i < dataRank; ++i) {
+          dataShapeTensor.emplace_back(rewriter.create<Torch::AtenSizeIntOp>(
+              loc, data,
+              rewriter.create<Torch::ConstantIntOp>(
+                  loc, rewriter.getI64IntegerAttr(i))));
+        }
+        dataShapeTensor[axis] = prod;
+        auto expandSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
+            dataShapeTensor);
+        auto expandedGather = rewriter.create<Torch::AtenExpandOp>(
+            loc, expandResultType, gatherOp, expandSizeList,
+            /*implicit=*/constFalse);
+
+        // Create the result size list for the final aten.view op: the index
+        // shape dimensions replace the `axis` dimension of the data shape.
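+        // In the example above, the [64000,1,1] gather result is expanded to
+        // [64000,4,5] and then viewed as the final [8,10,20,40,4,5] result.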
+        SmallVector<Value> resultShapeTensor;
+        for (unsigned i = 0; i < dataRank; ++i) {
+          if (i == axis) {
+            resultShapeTensor.insert(resultShapeTensor.end(),
+                                     indexShapeTensor.begin(),
+                                     indexShapeTensor.end());
+            continue;
+          }
+          resultShapeTensor.emplace_back(dataShapeTensor[i]);
+        }
+        auto resultSizeList = rewriter.create<Torch::PrimListConstructOp>(
+            loc, Torch::ListType::get(Torch::IntType::get(data.getContext())),
+            resultShapeTensor);
+
+        rewriter.replaceOpWithNewOp<Torch::AtenViewOp>(
+            binder.op, resultType, expandedGather, resultSizeList);
+        return success();
+      });
   patterns.onOp(
       "GatherElements", 13,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index e224ddfa2944c..7bf2654eb789a 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -31,6 +31,45 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[
   return %0 : !torch.vtensor<[3,4,5],i1>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @test_gather
+func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
+  // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[ARG1_SIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0]]
+  // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1]], %[[ARG1_SIZE0]]
+  // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+  // CHECK: %[[ARG1_SIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_0]]
+  // CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[ARG1_SIZE1]]
+  // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
+  // CHECK: %[[ARG1_SIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
+  // CHECK: %[[MUL3:.+]] = torch.aten.mul.int %[[MUL2]], %[[ARG1_SIZE2]] : !torch.int, !torch.int -> !torch.int
+  // CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
+  // CHECK: %[[ARG1_SIZE3:.+]] = torch.aten.size.int %arg1, %[[INT3]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int
+  // CHECK: %[[MUL4:.+]] = torch.aten.mul.int %[[MUL3]], %[[ARG1_SIZE3]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[VIEW1:.+]] = torch.aten.view %arg1, %[[LIST1]] : !torch.vtensor<[8,10,20,40],si64>, !torch.list<int> -> !torch.vtensor<[64000,1,1],si64>
+  // CHECK: %[[INT0_1:.+]] = torch.constant.int 0
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0_1]], %[[VIEW1]], %[[FALSE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.vtensor<[64000,1,1],si64>, !torch.bool -> !torch.vtensor<[64000,1,1],f32>
+  // CHECK: %[[INT0_2:.+]] = torch.constant.int 0
+  // CHECK: %[[ARG0_SIZE0:.+]] = torch.aten.size.int %arg0, %[[INT0_2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[INT1_3:.+]] = torch.constant.int 1
+  // CHECK: %[[ARG0_SIZE1:.+]] = torch.aten.size.int %arg0, %[[INT1_3]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[INT2_4:.+]] = torch.constant.int 2
+  // CHECK: %[[ARG0_SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2_4]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
+  // CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[GATHER]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[64000,1,1],f32>, !torch.list<int>, !torch.bool -> !torch.vtensor<[64000,4,5],f32>
+  // CHECK: %[[LIST3:.+]] = torch.prim.ListConstruct %[[ARG1_SIZE0]], %[[ARG1_SIZE1]], %[[ARG1_SIZE2]], %[[ARG1_SIZE3]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RES:.+]] = torch.aten.view %[[EXPAND]], %[[LIST3]] : !torch.vtensor<[64000,4,5],f32>, !torch.list<int> -> !torch.vtensor<[8,10,20,40,4,5],f32>
+  // CHECK: return %[[RES]] : !torch.vtensor<[8,10,20,40,4,5],f32>
+  %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32>
+  return %0 : !torch.vtensor<[8,10,20,40,4,5],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_gather_elements
 func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} {
   // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0