@@ -3019,94 +3019,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3019
3019
}
3020
3020
};
3021
3021
3022
- // Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
3023
- class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
3024
- public:
3025
- using OpRewritePattern::OpRewritePattern;
3026
-
3027
- // Do not create constants with more than `vectorSizeFoldThreashold` elements,
3028
- // unless the source vector constant has a single use.
3029
- static constexpr int64_t vectorSizeFoldThreshold = 256 ;
3030
-
3031
- LogicalResult matchAndRewrite (InsertOp op,
3032
- PatternRewriter &rewriter) const override {
3033
- // TODO: Canonicalization for dynamic position not implemented yet.
3034
- if (op.hasDynamicPosition ())
3035
- return failure ();
3022
+ } // namespace
3036
3023
3037
- // Return if 'InsertOp' operand is not defined by a compatible vector
3038
- // ConstantOp.
3039
- TypedValue<VectorType> destVector = op.getDest ();
3040
- Attribute vectorDestCst;
3041
- if (!matchPattern (destVector, m_Constant (&vectorDestCst)))
3042
- return failure ();
3043
- auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3044
- if (!denseDest)
3045
- return failure ();
3024
+ static Attribute
3025
+ foldDenseElementsAttrDestInsertOp (InsertOp insertOp, Attribute srcAttr,
3026
+ Attribute dstAttr,
3027
+ int64_t maxVectorSizeFoldThreshold) {
3028
+ if (insertOp.hasDynamicPosition ())
3029
+ return {};
3046
3030
3047
- VectorType destTy = destVector. getType ( );
3048
- if (destTy. isScalable () )
3049
- return failure () ;
3031
+ auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr );
3032
+ if (!denseDst )
3033
+ return {} ;
3050
3034
3051
- // Make sure we do not create too many large constants.
3052
- if (destTy.getNumElements () > vectorSizeFoldThreshold &&
3053
- !destVector.hasOneUse ())
3054
- return failure ();
3035
+ if (!srcAttr) {
3036
+ return {};
3037
+ }
3055
3038
3056
- Value sourceValue = op.getSource ();
3057
- Attribute sourceCst;
3058
- if (!matchPattern (sourceValue, m_Constant (&sourceCst)))
3059
- return failure ();
3039
+ VectorType destTy = insertOp.getDestVectorType ();
3040
+ if (destTy.isScalable ())
3041
+ return {};
3060
3042
3061
- // Calculate the linearized position of the continuous chunk of elements to
3062
- // insert.
3063
- llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3064
- copy (op.getStaticPosition (), completePositions.begin ());
3065
- int64_t insertBeginPosition =
3066
- linearize (completePositions, computeStrides (destTy.getShape ()));
3067
-
3068
- SmallVector<Attribute> insertedValues;
3069
- Type destEltType = destTy.getElementType ();
3070
-
3071
- // The `convertIntegerAttr` method specifically handles the case
3072
- // for `llvm.mlir.constant` which can hold an attribute with a
3073
- // different type than the return type.
3074
- if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3075
- for (auto value : denseSource.getValues <Attribute>())
3076
- insertedValues.push_back (convertIntegerAttr (value, destEltType));
3077
- } else {
3078
- insertedValues.push_back (convertIntegerAttr (sourceCst, destEltType));
3079
- }
3043
+ // Make sure we do not create too many large constants.
3044
+ if (destTy.getNumElements () > maxVectorSizeFoldThreshold &&
3045
+ !insertOp->hasOneUse ())
3046
+ return {};
3080
3047
3081
- auto allValues = llvm::to_vector (denseDest.getValues <Attribute>());
3082
- copy (insertedValues, allValues.begin () + insertBeginPosition);
3083
- auto newAttr = DenseElementsAttr::get (destTy, allValues);
3048
+ // Calculate the linearized position of the continuous chunk of elements to
3049
+ // insert.
3050
+ llvm::SmallVector<int64_t > completePositions (destTy.getRank (), 0 );
3051
+ copy (insertOp.getStaticPosition (), completePositions.begin ());
3052
+ int64_t insertBeginPosition =
3053
+ linearize (completePositions, computeStrides (destTy.getShape ()));
3084
3054
3085
- rewriter.replaceOpWithNewOp <arith::ConstantOp>(op, newAttr);
3086
- return success ();
3087
- }
3055
+ SmallVector<Attribute> insertedValues;
3056
+ Type destEltType = destTy.getElementType ();
3088
3057
3089
- private:
3090
3058
// / Converts the expected type to an IntegerAttr if there's
3091
3059
// / a mismatch.
3092
- Attribute convertIntegerAttr (Attribute attr, Type expectedType) const {
3060
+ auto convertIntegerAttr = [] (Attribute attr, Type expectedType) -> Attribute {
3093
3061
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
3094
3062
if (intAttr.getType () != expectedType)
3095
3063
return IntegerAttr::get (expectedType, intAttr.getInt ());
3096
3064
}
3097
3065
return attr;
3066
+ };
3067
+
3068
+ // The `convertIntegerAttr` method specifically handles the case
3069
+ // for `llvm.mlir.constant` which can hold an attribute with a
3070
+ // different type than the return type.
3071
+ if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3072
+ for (auto value : denseSource.getValues <Attribute>())
3073
+ insertedValues.push_back (convertIntegerAttr (value, destEltType));
3074
+ } else {
3075
+ insertedValues.push_back (convertIntegerAttr (srcAttr, destEltType));
3098
3076
}
3099
- };
3100
3077
3101
- } // namespace
3078
+ auto allValues = llvm::to_vector (denseDst.getValues <Attribute>());
3079
+ copy (insertedValues, allValues.begin () + insertBeginPosition);
3080
+ auto newAttr = DenseElementsAttr::get (destTy, allValues);
3081
+
3082
+ return newAttr;
3083
+ }
3102
3084
3103
3085
void InsertOp::getCanonicalizationPatterns (RewritePatternSet &results,
3104
3086
MLIRContext *context) {
3105
- results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3106
- InsertOpConstantFolder>(context);
3087
+ results.add <InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
3107
3088
}
3108
3089
3109
3090
OpFoldResult vector::InsertOp::fold (FoldAdaptor adaptor) {
3091
+ // Do not create constants with more than `vectorSizeFoldThreashold` elements,
3092
+ // unless the source vector constant has a single use.
3093
+ constexpr int64_t vectorSizeFoldThreshold = 256 ;
3110
3094
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
3111
3095
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
3112
3096
// (type mismatch).
@@ -3118,6 +3102,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3118
3102
if (auto res = foldPoisonIndexInsertExtractOp (
3119
3103
getContext (), adaptor.getStaticPosition (), kPoisonIndex ))
3120
3104
return res;
3105
+ if (auto res = foldDenseElementsAttrDestInsertOp (*this , adaptor.getSource (),
3106
+ adaptor.getDest (),
3107
+ vectorSizeFoldThreshold)) {
3108
+ return res;
3109
+ }
3121
3110
3122
3111
return {};
3123
3112
}
0 commit comments