Skip to content

Commit d10dca6

Browse files
authored
[mlir][Vector] Move vector.insert canonicalizers for DenseElementsAttr to folders (#128040)
This PR moves the vector.insert canonicalizers for DenseElementsAttr (both the splat and non-splat cases) to folders. Folders are local, and it is always better to implement a folder than a canonicalizer. This PR is mostly NFC-ish: the functionality remains mostly the same, but it now runs as part of a folder. Some tests change as a result, because GreedyPatternRewriter tries to fold by default.
1 parent 6701669 commit d10dca6

File tree

3 files changed

+63
-83
lines changed

3 files changed

+63
-83
lines changed

Diff for: mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+54-65
Original file line numberDiff line numberDiff line change
@@ -3019,94 +3019,78 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
30193019
}
30203020
};
30213021

3022-
// Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
3023-
class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
3024-
public:
3025-
using OpRewritePattern::OpRewritePattern;
3026-
3027-
// Do not create constants with more than `vectorSizeFoldThreshold` elements,
3028-
// unless the source vector constant has a single use.
3029-
static constexpr int64_t vectorSizeFoldThreshold = 256;
3030-
3031-
LogicalResult matchAndRewrite(InsertOp op,
3032-
PatternRewriter &rewriter) const override {
3033-
// TODO: Canonicalization for dynamic position not implemented yet.
3034-
if (op.hasDynamicPosition())
3035-
return failure();
3022+
} // namespace
30363023

3037-
// Return if 'InsertOp' operand is not defined by a compatible vector
3038-
// ConstantOp.
3039-
TypedValue<VectorType> destVector = op.getDest();
3040-
Attribute vectorDestCst;
3041-
if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
3042-
return failure();
3043-
auto denseDest = llvm::dyn_cast<DenseElementsAttr>(vectorDestCst);
3044-
if (!denseDest)
3045-
return failure();
3024+
static Attribute
3025+
foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
3026+
Attribute dstAttr,
3027+
int64_t maxVectorSizeFoldThreshold) {
3028+
if (insertOp.hasDynamicPosition())
3029+
return {};
30463030

3047-
VectorType destTy = destVector.getType();
3048-
if (destTy.isScalable())
3049-
return failure();
3031+
auto denseDst = llvm::dyn_cast_if_present<DenseElementsAttr>(dstAttr);
3032+
if (!denseDst)
3033+
return {};
30503034

3051-
// Make sure we do not create too many large constants.
3052-
if (destTy.getNumElements() > vectorSizeFoldThreshold &&
3053-
!destVector.hasOneUse())
3054-
return failure();
3035+
if (!srcAttr) {
3036+
return {};
3037+
}
30553038

3056-
Value sourceValue = op.getSource();
3057-
Attribute sourceCst;
3058-
if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
3059-
return failure();
3039+
VectorType destTy = insertOp.getDestVectorType();
3040+
if (destTy.isScalable())
3041+
return {};
30603042

3061-
// Calculate the linearized position of the continuous chunk of elements to
3062-
// insert.
3063-
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3064-
copy(op.getStaticPosition(), completePositions.begin());
3065-
int64_t insertBeginPosition =
3066-
linearize(completePositions, computeStrides(destTy.getShape()));
3067-
3068-
SmallVector<Attribute> insertedValues;
3069-
Type destEltType = destTy.getElementType();
3070-
3071-
// The `convertIntegerAttr` method specifically handles the case
3072-
// for `llvm.mlir.constant` which can hold an attribute with a
3073-
// different type than the return type.
3074-
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(sourceCst)) {
3075-
for (auto value : denseSource.getValues<Attribute>())
3076-
insertedValues.push_back(convertIntegerAttr(value, destEltType));
3077-
} else {
3078-
insertedValues.push_back(convertIntegerAttr(sourceCst, destEltType));
3079-
}
3043+
// Make sure we do not create too many large constants.
3044+
if (destTy.getNumElements() > maxVectorSizeFoldThreshold &&
3045+
!insertOp->hasOneUse())
3046+
return {};
30803047

3081-
auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
3082-
copy(insertedValues, allValues.begin() + insertBeginPosition);
3083-
auto newAttr = DenseElementsAttr::get(destTy, allValues);
3048+
// Calculate the linearized position of the continuous chunk of elements to
3049+
// insert.
3050+
llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
3051+
copy(insertOp.getStaticPosition(), completePositions.begin());
3052+
int64_t insertBeginPosition =
3053+
linearize(completePositions, computeStrides(destTy.getShape()));
30843054

3085-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
3086-
return success();
3087-
}
3055+
SmallVector<Attribute> insertedValues;
3056+
Type destEltType = destTy.getElementType();
30883057

3089-
private:
30903058
/// Converts the expected type to an IntegerAttr if there's
30913059
/// a mismatch.
3092-
Attribute convertIntegerAttr(Attribute attr, Type expectedType) const {
3060+
auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
30933061
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
30943062
if (intAttr.getType() != expectedType)
30953063
return IntegerAttr::get(expectedType, intAttr.getInt());
30963064
}
30973065
return attr;
3066+
};
3067+
3068+
// The `convertIntegerAttr` lambda specifically handles the case
3069+
// for `llvm.mlir.constant` which can hold an attribute with a
3070+
// different type than the return type.
3071+
if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
3072+
for (auto value : denseSource.getValues<Attribute>())
3073+
insertedValues.push_back(convertIntegerAttr(value, destEltType));
3074+
} else {
3075+
insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
30983076
}
3099-
};
31003077

3101-
} // namespace
3078+
auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
3079+
copy(insertedValues, allValues.begin() + insertBeginPosition);
3080+
auto newAttr = DenseElementsAttr::get(destTy, allValues);
3081+
3082+
return newAttr;
3083+
}
31023084

31033085
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
31043086
MLIRContext *context) {
3105-
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
3106-
InsertOpConstantFolder>(context);
3087+
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
31073088
}
31083089

31093090
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
3091+
// Do not create constants with more than `vectorSizeFoldThreshold` elements,
3092+
// unless the source vector constant has a single use.
3093+
constexpr int64_t vectorSizeFoldThreshold = 256;
31103094
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
31113095
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
31123096
// (type mismatch).
@@ -3118,6 +3102,11 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
31183102
if (auto res = foldPoisonIndexInsertExtractOp(
31193103
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
31203104
return res;
3105+
if (auto res = foldDenseElementsAttrDestInsertOp(*this, adaptor.getSource(),
3106+
adaptor.getDest(),
3107+
vectorSizeFoldThreshold)) {
3108+
return res;
3109+
}
31213110

31223111
return {};
31233112
}

Diff for: mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+3-7
Original file line numberDiff line numberDiff line change
@@ -1517,13 +1517,9 @@ func.func @constant_mask_2d() -> vector<4x4xi1> {
15171517
}
15181518

15191519
// CHECK-LABEL: func @constant_mask_2d
1520-
// CHECK: %[[VAL_0:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
1521-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<false> : vector<4x4xi1>
1522-
// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[VAL_1]] : vector<4x4xi1> to !llvm.array<4 x vector<4xi1>>
1523-
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][0] : !llvm.array<4 x vector<4xi1>>
1524-
// CHECK: %[[VAL_4:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_3]][1] : !llvm.array<4 x vector<4xi1>>
1525-
// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !llvm.array<4 x vector<4xi1>> to vector<4x4xi1>
1526-
// CHECK: return %[[VAL_5]] : vector<4x4xi1>
1520+
// CHECK: %[[VAL_0:.*]] = arith.constant
1521+
// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
1522+
// CHECK: return %[[VAL_0]] : vector<4x4xi1>
15271523

15281524
// -----
15291525

Diff for: mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir

+6-11
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,19 @@ func.func @genbool_1d() -> vector<8xi1> {
1010
}
1111

1212
// CHECK-LABEL: func @genbool_2d
13-
// CHECK: %[[C1:.*]] = arith.constant dense<[true, true, false, false]> : vector<4xi1>
14-
// CHECK: %[[C2:.*]] = arith.constant dense<false> : vector<4x4xi1>
15-
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<4x4xi1>
16-
// CHECK: %[[T1:.*]] = vector.insert %[[C1]], %[[T0]] [1] : vector<4xi1> into vector<4x4xi1>
17-
// CHECK: return %[[T1]] : vector<4x4xi1>
13+
// CHECK: %[[C0:.*]] = arith.constant
14+
// CHECK-SAME{LITERAL}: dense<[[true, true, false, false], [true, true, false, false], [false, false, false, false], [false, false, false, false]]> : vector<4x4xi1>
15+
// CHECK: return %[[C0]] : vector<4x4xi1>
1816

1917
func.func @genbool_2d() -> vector<4x4xi1> {
2018
%v = vector.constant_mask [2, 2] : vector<4x4xi1>
2119
return %v: vector<4x4xi1>
2220
}
2321

2422
// CHECK-LABEL: func @genbool_3d
25-
// CHECK-DAG: %[[C1:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
26-
// CHECK-DAG: %[[C2:.*]] = arith.constant dense<false> : vector<3x4xi1>
27-
// CHECK-DAG: %[[C3:.*]] = arith.constant dense<false> : vector<2x3x4xi1>
28-
// CHECK: %[[T0:.*]] = vector.insert %[[C1]], %[[C2]] [0] : vector<4xi1> into vector<3x4xi1>
29-
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C3]] [0] : vector<3x4xi1> into vector<2x3x4xi1>
30-
// CHECK: return %[[T1]] : vector<2x3x4xi1>
23+
// CHECK: %[[C0:.*]] = arith.constant
24+
// CHECK-SAME{LITERAL}: dense<[[[true, true, true, false], [false, false, false, false], [false, false, false, false]], [[false, false, false, false], [false, false, false, false], [false, false, false, false]]]> : vector<2x3x4xi1>
25+
// CHECK: return %[[C0]] : vector<2x3x4xi1>
3126

3227
func.func @genbool_3d() -> vector<2x3x4xi1> {
3328
%v = vector.constant_mask [1, 1, 3] : vector<2x3x4xi1>

0 commit comments

Comments
 (0)