Commit a28ff29

[AutoBump] Merge with fixes of 7e622b6 (Jan 22)
2 parents: cccba22 + 7e622b6

File tree

16 files changed: +205 −181 lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 5 additions & 5 deletions
@@ -1581,21 +1581,21 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Example:
 
     ```mlir
-    %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<4x9xf32>)
+    %0 = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<4x9xf32>)
     ```
 
     Example 2:
 
     ```mlir
-    %0 = arith.constant dense<[-1, 2, 3, 4]> : tensor<4xi32>
-    tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<4xi32>) -> (tensor<?x9xf32>)
+    %0 = tosa.const_shape { value = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
+    tosa.pad %arg0, %0 : (tensor<1x2xf32>, !tosa.shape<4>) -> (tensor<?x9xf32>)
     ```
   }];
 
   let arguments = (ins
     Tosa_RankedTensor:$input1,
-    TosaTensorRankOf<[Tosa_Int32Or64], [1]>:$padding,
+    Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
     OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
   );
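Only the padding operand changes type here; the optional pad_const operand is untouched. A minimal sketch of the three-operand form under the new type, assuming a rank-0 pad_const as the Tosa_ScalarTensor constraint suggests:

```mlir
%padding = tosa.const_shape { value = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
%pad_const = "tosa.const"() { value = dense<3.14> : tensor<f32> } : () -> tensor<f32>
%1 = tosa.pad %arg0, %padding, %pad_const : (tensor<1x2xf32>, !tosa.shape<4>, tensor<f32>) -> tensor<4x9xf32>
```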

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 8 additions & 0 deletions
@@ -242,6 +242,14 @@ SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
   return permuted;
 }
 
+// Computes shape value using tosa const_shape op.
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape);
+SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
+
+bool getConstShapeValue(Operation *op,
+                        llvm::SmallVector<int64_t> &result_shape);
+
 } // namespace tosa
 } // namespace mlir
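getTosaConstShape is the builder-side helper: it materializes a static shape as a tosa.const_shape op. A sketch of the IR expected from getTosaConstShape(rewriter, loc, {0, 0, 1, 1, 2, 2, 0, 0}) (values illustrative):

```mlir
%padding = tosa.const_shape { value = dense<[0, 0, 1, 1, 2, 2, 0, 0]> : tensor<8xindex> } : () -> !tosa.shape<8>
```

getConstShapeValue is the matcher-side counterpart: it fills result_shape and returns true only when op is a tosa.const_shape (see its definition in ConversionUtils.cpp below).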

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 14 additions & 13 deletions
@@ -306,7 +306,16 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = padOp.getLoc();
     auto input = padOp.getInput1();
-    auto padding = padOp.getPadding();
+
+    ElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          padOp, "padding must be a static shape value");
+    }
+    llvm::SmallVector<int64_t> paddingVals;
+    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
+      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
+    }
 
     ShapedType inputTy = cast<ShapedType>(input.getType());
     Type elementTy = inputTy.getElementType();
@@ -345,18 +354,10 @@ class PadConverter : public OpConversionPattern<tosa::PadOp> {
     highValues.reserve(rank);
 
     for (int i = 0; i < rank; i++) {
-      Value lowIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i);
-      Value highIndex = rewriter.create<arith::ConstantIndexOp>(loc, 2 * i + 1);
-      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({lowIndex}));
-      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
-          loc, padding, ValueRange({highIndex}));
-
-      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), lowVal);
-      highVal = rewriter.createOrFold<arith::IndexCastOp>(
-          loc, rewriter.getIndexType(), highVal);
-
+      Value lowVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i]));
+      Value highVal = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));
       lowValues.push_back(lowVal);
       highValues.push_back(highVal);
     }
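Since the padding must now be a compile-time tosa.const_shape value, the converter folds the amounts straight into index constants instead of emitting tensor.extract plus index_cast chains. A hedged before/after sketch (shapes and names assumed):

```mlir
// Before lowering: padding of [low0, high0, low1, high1] = [1, 1, 2, 2].
%padding = tosa.const_shape { value = dense<[1, 1, 2, 2]> : tensor<4xindex> } : () -> !tosa.shape<4>
%0 = tosa.pad %arg0, %padding : (tensor<4x4xf32>, !tosa.shape<4>) -> tensor<6x8xf32>

// After PadConverter (sketch): static index operands feed tensor.pad directly.
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%zero = arith.constant 0.000000e+00 : f32
%1 = tensor.pad %arg0 low[%c1, %c2] high[%c1, %c2] {
^bb0(%i: index, %j: index):
  tensor.yield %zero : f32
} : tensor<4x4xf32> to tensor<6x8xf32>
```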

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 20 additions & 30 deletions
@@ -36,6 +36,7 @@ using namespace mlir;
 using namespace mlir::tosa;
 
 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 
 //===----------------------------------------------------------------------===//
 // Tosa dialect interface includes.
@@ -857,51 +858,42 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
     PadOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
-  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
+  auto paddingRank =
+      cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
   SmallVector<int64_t> outputShape;
 
-  // If both inputs have unknown shape, we cannot determine the shape of the
-  // output.
-  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
-    inferredReturnShapes.push_back(ShapedTypeComponents());
-    return success();
-  }
-
-  // If the input rank is unknown we can info the output rank using the
-  // padding shape's first dim.
+  // If the input rank is unknown, we can infer the output rank using the
+  // padding shape's rank divided by 2.
   if (!inputShape.hasRank()) {
-    if (paddingShape.isDynamicDim(0)) {
-      inferredReturnShapes.push_back(ShapedTypeComponents());
-      return success();
-    }
-
-    outputShape.resize(paddingShape.getDimSize(0) / 2, ShapedType::kDynamic);
+    outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  DenseIntElementsAttr paddings;
+  SmallVector<int64_t> paddingValues;
   // If the paddings value is not a constant, all dimensions must be dynamic.
-  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
+  if (!tosa::getConstShapeValue(adaptor.getPadding().getDefiningOp(),
+                                paddingValues)) {
     outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
     inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
     return success();
   }
 
-  SmallVector<int64_t> paddingValues;
-  for (auto val : paddings) {
-    paddingValues.push_back(val.getSExtValue());
-  }
-
   outputShape.reserve(inputShape.getRank());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
     if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }
+    auto padFront = paddingValues[i * 2];
+    auto padBack = paddingValues[i * 2 + 1];
+    if (padFront < 0 || padBack < 0) {
+      // if either padding for dim i is -1, output dim is unknown
+      outputShape.push_back(ShapedType::kDynamic);
+      continue;
+    }
 
-    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
-                          paddingValues[i * 2 + 1]);
+    outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
   }
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
@@ -911,17 +903,15 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 LogicalResult tosa::PadOp::verify() {
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
-  RankedTensorType paddingType = getPadding().getType();
+  auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
 
   if (inputType.getRank() != outputType.getRank())
     return emitOpError() << "expect same input and output tensor rank.";
 
-  if (!paddingType.isDynamicDim(0) &&
-      paddingType.getDimSize(0) != inputType.getRank() * 2)
+  if (paddingRank != inputType.getRank() * 2)
     return emitOpError() << "expected padding tensor dim 0 to have size "
                          << inputType.getRank() * 2
-                         << " (2*rank(shape1)) but got size "
-                         << paddingType.getDimSize(0);
+                         << " (2*rank(shape1)) but got size " << paddingRank;
 
   return success();
 }
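The padding rank is now static in the !tosa.shape type, so verify() can always check it. A sketch of IR the updated verifier rejects, with the diagnostic taken from the code above:

```mlir
// A rank-2 input requires 2 * 2 = 4 padding entries, i.e. !tosa.shape<4>.
%p = tosa.const_shape { value = dense<[0, 0, 1, 1, 2, 2]> : tensor<6xindex> } : () -> !tosa.shape<6>
// error: 'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6
%0 = tosa.pad %arg0, %p : (tensor<1x2xf32>, !tosa.shape<6>) -> tensor<2x5xf32>
```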

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp

Lines changed: 1 addition & 5 deletions
@@ -108,11 +108,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       }
     }
 
-    auto padSizeTy = RankedTensorType::get({10}, rewriter.getI64Type());
-    auto padSize =
-        DenseIntElementsAttr::get(padSizeTy, ArrayRef<int64_t>(pad));
-    Value padSizeVal =
-        rewriter.create<tosa::ConstOp>(op->getLoc(), padSizeTy, padSize);
+    Value padSizeVal = getTosaConstShape(rewriter, op->getLoc(), pad);
 
     auto padTy = RankedTensorType::get({}, inputETy);
     auto padAttr = DenseElementsAttr::get(padTy, zeroAttr);

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 11 additions & 18 deletions
@@ -135,15 +135,14 @@ class TransposeConvStridedConverter
     int64_t inputChannels = weightTy.getDimSize(3);
 
     // Pad the weight so that it is modulo of the striding.
-    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     weightPadding[3] =
         (weightHeight % stride[0]) ? (stride[0] - weightHeight % stride[0]) : 0;
     weightPadding[5] =
-        (weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
-    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), weightPadding);
-    Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
+        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
+
+    Value weightPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), weightPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -197,17 +196,14 @@ class TransposeConvStridedConverter
         /* axis = */ rewriter.getI32IntegerAttr(2));
 
     // We need to pad the input far enough that we can pull all values.
-    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
     inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
 
-    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), inputPadding);
-
-    Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
+    Value inputPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), inputPadding);
 
     if (op.getQuantizationInfo().has_value()) {
       auto quantInfo = op.getQuantizationInfo().value();
@@ -310,17 +306,14 @@ class TransposeConvStridedConverter
                              rewriter.getDenseI64ArrayAttr(sliceSize))
             .getResult();
 
-    llvm::SmallVector<int32_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
+    llvm::SmallVector<int64_t, 8> resultPadding = {0, 0, 0, 0, 0, 0, 0, 0};
     resultPadding[2] = resultPadTop;
     resultPadding[3] = resultTy.getDimSize(1) - resultPadTop - sliceSize[1];
     resultPadding[4] = resultPadLeft;
     resultPadding[5] = resultTy.getDimSize(2) - resultPadLeft - sliceSize[2];
 
-    DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
-        RankedTensorType::get({8}, rewriter.getI32Type()), resultPadding);
-
-    Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
-        rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
+    Value resultPaddingVal =
+        getTosaConstShape(rewriter, op->getLoc(), resultPadding);
 
     Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
         rewriter, loc, UnrankedTensorType::get(resultETy), slice,

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 23 additions & 7 deletions
@@ -162,13 +162,11 @@ LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
   return success();
 }
 
-namespace {
-SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape) {
+SmallVector<int64_t> mlir::tosa::convertFromMlirShape(ArrayRef<int64_t> shape) {
   return to_vector(llvm::map_range(shape, [](int64_t dim) {
     return ShapedType::isDynamic(dim) ? -1 : dim;
   }));
 }
-} // namespace
 
 Value mlir::tosa::getTosaConstShape(ImplicitLocOpBuilder &builder,
                                     llvm::ArrayRef<int64_t> shape) {
@@ -220,14 +218,32 @@ LogicalResult mlir::tosa ::getConvOpsAccType(PatternRewriter &rewriter,
   } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
              outputElemTy.isInteger(48)) {
     accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
-  } else if ((isa<Float8E4M3FNType>(inputElemTy) && isa<Float8E4M3FNType>(weightElemTy) &&
-              outputElemTy.isF16()) ||
-             (isa<Float8E5M2Type>(inputElemTy) && isa<Float8E5M2Type>(weightElemTy) &&
-              outputElemTy.isF16())) {
+  } else if ((isa<Float8E4M3FNType>(inputElemTy) &&
+              isa<Float8E4M3FNType>(weightElemTy) && outputElemTy.isF16()) ||
+             (isa<Float8E5M2Type>(inputElemTy) &&
+              isa<Float8E5M2Type>(weightElemTy) && outputElemTy.isF16())) {
     accType = mlir::TypeAttr::get(rewriter.getF16Type());
   } else {
     accType = mlir::TypeAttr::get(outputElemTy);
   }
 
   return success();
 }
+
+bool mlir::tosa::getConstShapeValue(Operation *op,
+                                    llvm::SmallVector<int64_t> &result_shape) {
+  if (!op) {
+    return false;
+  }
+  if (auto constOp = mlir::dyn_cast<tosa::ConstShapeOp>(op)) {
+    Attribute constOpAttr = constOp->getAttr("value");
+    DenseElementsAttr elementsAttr = cast<DenseElementsAttr>(constOpAttr);
+    for (int i = 0; i < elementsAttr.size(); i++) {
+      int64_t val = elementsAttr.getValues<int64_t>()[i];
+      result_shape.push_back(val);
+    }
+    return true;
+  }
+  // for undefined op, return false.
+  return false;
+}
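getConstShapeValue only folds values produced directly by a tosa.const_shape op; anything else, including a block argument (which has no defining op), returns false. A sketch of both cases:

```mlir
// Foldable: the defining op is tosa.const_shape.
%ok = tosa.const_shape { value = dense<[0, 1, 0, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>

// Not foldable: %padding is a block argument, so getDefiningOp() is null and
// PadOp::inferReturnTypeComponents keeps every output dimension dynamic.
func.func @pad_dynamic(%arg0: tensor<1x2xf32>, %padding: !tosa.shape<4>) -> tensor<?x?xf32> {
  %0 = tosa.pad %arg0, %padding : (tensor<1x2xf32>, !tosa.shape<4>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```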
