diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index c0e9af22ceb9..4756315c800e 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -545,8 +545,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*dtype=*/noneVal);
         return success();
       });
+  // onnx.ReduceMean with axes provided as an argument, introduced in opset 18
   patterns.onOp(
-      "ReduceMean", 13,
+      "ReduceMean", 18,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
         Value data;
@@ -632,6 +633,82 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*dtype=*/noneVal);
         return success();
       });
+
+  // onnx.ReduceMean with axes provided as an attribute
+  patterns.onOp(
+      "ReduceMean", 1,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value data;
+        llvm::SmallVector<int64_t> axes;
+        int64_t keepDims;
+        int64_t noop_with_empty_axes;
+        if (binder.tensorOperand(data) ||
+            binder.tensorResultType(resultType) ||
+            binder.s64IntegerArrayAttr(axes, "axes", 0) ||
+            binder.s64IntegerAttr(keepDims, "keepdims", 1) ||
+            binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes",
+                                  0))
+          return failure();
+        SmallVector<Value> dimList;
+        SmallVector<int64_t> selectSizes;
+        selectSizes.push_back(1);
+        Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        // deal with the case when axes is empty
+        if (axes.size() == 0) {
+          if (noop_with_empty_axes == 0) {
+            Value keepDimsConstInt = rewriter.create<Torch::ConstantIntOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(),
+                rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims));
+            Value keepDimsBool = rewriter.create<Torch::AtenBoolIntOp>(
+                binder.getLoc(), keepDimsConstInt);
+            rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
+                binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool,
+                /*dtype=*/noneVal);
+          } else {
+            rewriter.replaceOp(binder.op, data);
+          }
+          return success();
+        }
+        Value zero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+        int64_t adjustmentInt =
+            cast<Torch::ValueTensorType>(data.getType()).getSizes().size();
+        Value adjustment = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    adjustmentInt));
+        // convert the axes attribute into a torch int list, handling neg axes
+        for (int i = 0; i < axes.size(); i++) {
+          // Go through the axes list and get each dim in the list
+          int64_t dim = axes[i];
+          if (dim < 0) {
+            dim += adjustmentInt;
+          }
+          // deal with neg axis: if (axis < 0) axis += rank
+          Value finalDim = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim));
+          dimList.push_back(finalDim);
+        }
+        Value dimValueList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
+            dimList);
+        Value keepDimBool;
+        if (keepDims == 1) {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
+        } else {
+          keepDimBool =
+              rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        }
+        rewriter.replaceOpWithNewOp<Torch::AtenMeanDimOp>(
+            binder.op, resultType, data, dimValueList, keepDimBool,
+            /*dtype=*/noneVal);
+        return success();
+      });
   patterns.onOp(
       "ReduceMin", 13,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir
index 8401c378b77c..3ed9f1c6ebe6 100644
--- a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir
+++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir
@@ -34,7 +34,6 @@ func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>,
 func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32>
                             attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // The ReduceMean operation as provided.
-  // expected-error @+1 {{failed to legalize operation 'torch.operator'}}
   %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32>
   return %211 : !torch.vtensor<[1,64,1],f32>
 }
\ No newline at end of file
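For reference, with the new attribute-based pattern the reduce_mean_operation test case above is expected to lower to torch-dialect IR roughly like the following. This is an illustrative sketch, not CHECK lines from the test: the SSA names are made up, the unused zero/rank constants the pattern also materializes are omitted, and axis -1 is normalized to 2 for the rank-3 input (keepdims defaults to 1, hence %true).

  %none = torch.constant.none
  %int2 = torch.constant.int 2
  %dims = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list<int>
  %true = torch.constant.bool true
  %0 = torch.aten.mean.dim %arg0, %dims, %true, %none : !torch.vtensor<[1,64,768],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,64,1],f32>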