@@ -545,8 +545,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
545545 /* dtype=*/ noneVal);
546546 return success ();
547547 });
548+ // onnx.ReduceMean with axes provided as argument introduced in opset 18
548549 patterns.onOp (
549- " ReduceMean" , 13 ,
550+ " ReduceMean" , 18 ,
550551 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
551552 Torch::ValueTensorType resultType;
552553 Value data;
@@ -632,6 +633,82 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
632633 /* dtype=*/ noneVal);
633634 return success ();
634635 });
636+
637+ // onnx.ReduceMean with axes provided as attribute
638+ patterns.onOp (
639+ " ReduceMean" , 1 ,
640+ [](OpBinder binder, ConversionPatternRewriter &rewriter) {
641+ Torch::ValueTensorType resultType;
642+ Value data;
643+ llvm::SmallVector<int64_t > axes;
644+ int64_t keepDims;
645+ int64_t noop_with_empty_axes;
646+ if (binder.tensorOperand (data) ||
647+ binder.tensorResultType (resultType) ||
648+ binder.s64IntegerArrayAttr (axes, " axes" , 0 ) ||
649+ binder.s64IntegerAttr (keepDims, " keepdims" , 1 ) ||
650+ binder.s64IntegerAttr (noop_with_empty_axes, " noop_with_empty_axes" ,
651+ 0 ))
652+ return failure ();
653+ SmallVector<Value> dimList;
654+ SmallVector<int64_t > selectSizes;
655+ selectSizes.push_back (1 );
656+ Value noneVal = rewriter.create <Torch::ConstantNoneOp>(binder.getLoc ());
657+ // deal with case when axes is empty
658+ if (axes.size () == 0 ) {
659+ if (noop_with_empty_axes == 0 ) {
660+ Value keepDimsConstInt = rewriter.create <Torch::ConstantIntOp>(
661+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
662+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), keepDims));
663+ Value keepDimsBool = rewriter.create <Torch::AtenBoolIntOp>(
664+ binder.getLoc (), keepDimsConstInt);
665+ rewriter.replaceOpWithNewOp <Torch::AtenMeanDimOp>(
666+ binder.op , resultType, data, /* dim=*/ noneVal, keepDimsBool,
667+ /* dtype=*/ noneVal);
668+ } else {
669+ rewriter.replaceOp (binder.op , data);
670+ }
671+ return success ();
672+ }
673+ Value zero = rewriter.create <Torch::ConstantIntOp>(
674+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
675+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), 0 ));
676+ int64_t adjustmentInt =
677+ cast<Torch::ValueTensorType>(data.getType ()).getSizes ().size ();
678+ Value adjustment = rewriter.create <Torch::ConstantIntOp>(
679+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
680+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
681+ adjustmentInt));
682+ // convert axes (tensor) into torch int list while dealing with neg axis
683+ for (int i = 0 ; i < axes.size (); i++) {
684+ // Go through the axes list and get each dim in the list
685+ int64_t dim = axes[i];
686+ if (dim < 0 ) {
687+ dim += adjustmentInt;
688+ }
689+ // deal with neg axis: if (axis < 0) axis += rank
690+ Value finalDim = rewriter.create <Torch::ConstantIntOp>(
691+ binder.getLoc (), rewriter.getType <Torch::IntType>(),
692+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ), dim));
693+ dimList.push_back (finalDim);
694+ }
695+ Value dimValueList = rewriter.create <Torch::PrimListConstructOp>(
696+ binder.getLoc (),
697+ Torch::ListType::get (Torch::IntType::get (binder.op ->getContext ())),
698+ dimList);
699+ Value keepDimBool;
700+ if (keepDims == 1 ) {
701+ keepDimBool =
702+ rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), true );
703+ } else {
704+ keepDimBool =
705+ rewriter.create <Torch::ConstantBoolOp>(binder.getLoc (), false );
706+ }
707+ rewriter.replaceOpWithNewOp <Torch::AtenMeanDimOp>(
708+ binder.op , resultType, data, dimValueList, keepDimBool,
709+ /* dtype=*/ noneVal);
710+ return success ();
711+ });
635712 patterns.onOp (
636713 " ReduceMin" , 13 ,
637714 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
0 commit comments