-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg #155967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (ShivaChen) Changes: The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether the dynamic extension is enabled. When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps. Patch is 20.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155967.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 73046e0da361a..cc1289f397dff 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1342,6 +1342,186 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+ Location loc) {
+ SmallVector<ReassociationExprs, 1> reassociation;
+ // Create the collapsed type
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto elemType = inputType.getElementType();
+ auto collapsedType = RankedTensorType::get({}, elemType);
+ // Emit the collapse op
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+ reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
+// by:
+// 1. Pushing it into 'genericInputs'.
+// 2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the multiplier is constant, set 'multiplierConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+ PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+ SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+ bool isConstant, tosa::RescaleOp op, Value &multiplierConstant) {
+
+ auto loc = op.getLoc();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ unsigned rank = inputTy.getRank();
+ SmallVector<AffineExpr, 2> multiplierExprs{
+ rewriter.getAffineDimExpr(rank - 1)};
+
+ if (isConstant) {
+ // If we are rescaling per-channel then we need to store the multiplier
+ // values in a buffer.
+ if (multiplierValues.size() == 1) {
+ multiplierConstant = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+ } else {
+ auto multiplierType =
+ RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+ rewriter.getI32Type());
+ genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+ loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, multiplierExprs,
+ rewriter.getContext()));
+ }
+ } else {
+ if (op.getMultiplier().getType().getRank() == 1) {
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap =
+ AffineMap::get(rank, 0, {}, rewriter.getContext());
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+ indexingMaps.push_back(broadcastMap);
+ } else {
+ genericInputs.push_back(op.getMultiplier());
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, multiplierExprs,
+ rewriter.getContext()));
+ }
+ }
+}
+
+// The shift may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
+// 1. Pushing it into 'genericInputs'.
+// 2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift is constant, set 'shiftConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+ PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+ SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+ bool isConstant, tosa::RescaleOp op, Value &shiftConstant) {
+
+ auto loc = op.getLoc();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ unsigned rank = inputTy.getRank();
+ SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+ if (isConstant) {
+ // If we are rescaling per-channel then we need to store the shift
+ // values in a buffer.
+ if (shiftValues.size() == 1) {
+ shiftConstant = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+ } else {
+ auto shiftType =
+ RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+ rewriter.getIntegerType(8));
+ genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+ loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, shiftExprs,
+ rewriter.getContext()));
+ }
+ } else {
+ if (op.getShift().getType().getRank() == 1) {
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap =
+ AffineMap::get(rank, 0, {}, rewriter.getContext());
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op.getShift(), loc));
+ indexingMaps.push_back(broadcastMap);
+ } else {
+ genericInputs.push_back(op.getShift());
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, shiftExprs,
+ rewriter.getContext()));
+ }
+ }
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+ FailureOr<int64_t> maybeZp, Location loc,
+ ValueRange blockArgs) {
+ Value result;
+ // The Zp value can be either constant or non-constant, depending on
+ // whether dynamic extension is enabled.
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+ // be passed as an input to linalg::GenericOp.
+ if (failed(maybeZp)) {
+ result = blockArgs[3];
+ auto zpTy = result.getType();
+ if (zpTy.getIntOrFloatBitWidth() < 32) {
+ if (zpTy.isUnsignedInteger()) {
+ result =
+ builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+ } else {
+ result =
+ builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+ }
+ }
+ } else {
+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+ // Extend zeropoint for sub-32bits widths.
+ const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+ result = builder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+ }
+ return result;
+}
+
+// Return the i32 outputZp to be used in subsequent arithmetic operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+ FailureOr<int64_t> maybeZp, Location loc,
+ ValueRange blockArgs) {
+ Value result;
+ // The Zp value can be either constant or non-constant, depending on
+ // whether dynamic extension is enabled.
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+ // be passed as an input to linalg::GenericOp.
+ if (failed(maybeZp)) {
+ result = blockArgs[4];
+ auto zpTy = result.getType();
+ if (zpTy.getIntOrFloatBitWidth() < 32) {
+ if (zpTy.isUnsignedInteger()) {
+ result =
+ builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+ } else {
+ result =
+ builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+ }
+ } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+ result =
+ builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+ }
+ } else {
+ const int32_t attrBitwidth = 32;
+ result = builder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+ }
+ return result;
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1375,41 +1555,45 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int8_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ StringAttr roundingMode;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+ // explicit cast is required here
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
-
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
+ // Double round only occurs if shift is greater than 31, check that this
+ // is ever true.
+ doubleRound = op.getRoundingMode() == "DOUBLE_ROUND" &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == "DOUBLE_ROUND";
- bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND")
+ : rewriter.getStringAttr("SINGLE_ROUND");
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1418,46 +1602,35 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// If we are rescaling per-channel then we need to store the multiplier
// values in a buffer.
Value multiplierConstant;
- int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
-
+ setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
- int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMapForShift(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+ // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ }
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
}
// Indexing maps for output values.
@@ -1477,40 +1650,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
-
+ auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
-
- Value multiplier = multiplierConstant ? multiplierConstant
- : blockArgs[multiplierArg];
- Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
+
+ Value multiplier =
+ multiplierConstant ? multiplierConstant : blockArgs[1];
+ Value shift = shiftConstant ? shiftConstant : blockArgs[2];
if (valueTy.isUnsignedInteger()) {
value = UnrealizedConversionCastOp::create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aee0caa91043d..8313173e1fec9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,34 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
// -----
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+ // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+ // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c-128_i32 = arith.constant -128 : i32
+ // CHECK: %c127_i32 = arith.constant 127 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false}...
[truncated]
|
@llvm/pr-subscribers-mlir-tosa Author: None (ShivaChen) Changes: The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether the dynamic extension is enabled. When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps. Patch is 20.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155967.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 73046e0da361a..cc1289f397dff 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1342,6 +1342,186 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+ Location loc) {
+ SmallVector<ReassociationExprs, 1> reassociation;
+ // Create the collapsed type
+ auto inputType = cast<RankedTensorType>(input.getType());
+ auto elemType = inputType.getElementType();
+ auto collapsedType = RankedTensorType::get({}, elemType);
+ // Emit the collapse op
+ return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+ reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
+// by:
+// 1. Pushing it into 'genericInputs'.
+// 2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the multiplier is constant, set 'multiplierConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+ PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+ SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+ bool isConstant, tosa::RescaleOp op, Value &multiplierConstant) {
+
+ auto loc = op.getLoc();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ unsigned rank = inputTy.getRank();
+ SmallVector<AffineExpr, 2> multiplierExprs{
+ rewriter.getAffineDimExpr(rank - 1)};
+
+ if (isConstant) {
+ // If we are rescaling per-channel then we need to store the multiplier
+ // values in a buffer.
+ if (multiplierValues.size() == 1) {
+ multiplierConstant = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+ } else {
+ auto multiplierType =
+ RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+ rewriter.getI32Type());
+ genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+ loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, multiplierExprs,
+ rewriter.getContext()));
+ }
+ } else {
+ if (op.getMultiplier().getType().getRank() == 1) {
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap =
+ AffineMap::get(rank, 0, {}, rewriter.getContext());
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+ indexingMaps.push_back(broadcastMap);
+ } else {
+ genericInputs.push_back(op.getMultiplier());
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, multiplierExprs,
+ rewriter.getContext()));
+ }
+ }
+}
+
+// The shift may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
+// 1. Pushing it into 'genericInputs'.
+// 2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift is constant, set 'shiftConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+ PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+ SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+ bool isConstant, tosa::RescaleOp op, Value &shiftConstant) {
+
+ auto loc = op.getLoc();
+ auto inputTy = cast<ShapedType>(op.getInput().getType());
+ unsigned rank = inputTy.getRank();
+ SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+ if (isConstant) {
+ // If we are rescaling per-channel then we need to store the shift
+ // values in a buffer.
+ if (shiftValues.size() == 1) {
+ shiftConstant = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+ } else {
+ auto shiftType =
+ RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+ rewriter.getIntegerType(8));
+ genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+ loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, shiftExprs,
+ rewriter.getContext()));
+ }
+ } else {
+ if (op.getShift().getType().getRank() == 1) {
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap =
+ AffineMap::get(rank, 0, {}, rewriter.getContext());
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op.getShift(), loc));
+ indexingMaps.push_back(broadcastMap);
+ } else {
+ genericInputs.push_back(op.getShift());
+ indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+ /*symbolCount=*/0, shiftExprs,
+ rewriter.getContext()));
+ }
+ }
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+ FailureOr<int64_t> maybeZp, Location loc,
+ ValueRange blockArgs) {
+ Value result;
+ // The Zp value can be either constant or non-constant, depending on
+ // whether dynamic extension is enabled.
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+ // be passed as an input to linalg::GenericOp.
+ if (failed(maybeZp)) {
+ result = blockArgs[3];
+ auto zpTy = result.getType();
+ if (zpTy.getIntOrFloatBitWidth() < 32) {
+ if (zpTy.isUnsignedInteger()) {
+ result =
+ builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+ } else {
+ result =
+ builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+ }
+ }
+ } else {
+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+ // Extend zeropoint for sub-32bits widths.
+ const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+ result = builder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+ }
+ return result;
+}
+
+// Return the i32 outputZp to be used in subsequent arithmetic operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+ FailureOr<int64_t> maybeZp, Location loc,
+ ValueRange blockArgs) {
+ Value result;
+ // The Zp value can be either constant or non-constant, depending on
+ // whether dynamic extension is enabled.
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+ // be passed as an input to linalg::GenericOp.
+ if (failed(maybeZp)) {
+ result = blockArgs[4];
+ auto zpTy = result.getType();
+ if (zpTy.getIntOrFloatBitWidth() < 32) {
+ if (zpTy.isUnsignedInteger()) {
+ result =
+ builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+ } else {
+ result =
+ builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+ }
+ } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+ result =
+ builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+ }
+ } else {
+ const int32_t attrBitwidth = 32;
+ result = builder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+ }
+ return result;
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1375,41 +1555,45 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int8_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ StringAttr roundingMode;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+ // explicit cast is required here
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
-
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
+ // Double round only occurs if shift is greater than 31, check that this
+ // is ever true.
+ doubleRound = op.getRoundingMode() == "DOUBLE_ROUND" &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == "DOUBLE_ROUND";
- bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND")
+ : rewriter.getStringAttr("SINGLE_ROUND");
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1418,46 +1602,35 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// If we are rescaling per-channel then we need to store the multiplier
// values in a buffer.
Value multiplierConstant;
- int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
-
+ setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
- int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMapForShift(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+ // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ }
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
}
// Indexing maps for output values.
@@ -1477,40 +1650,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
-
+ auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
-
- Value multiplier = multiplierConstant ? multiplierConstant
- : blockArgs[multiplierArg];
- Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
+
+ Value multiplier =
+ multiplierConstant ? multiplierConstant : blockArgs[1];
+ Value shift = shiftConstant ? shiftConstant : blockArgs[2];
if (valueTy.isUnsignedInteger()) {
value = UnrealizedConversionCastOp::create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aee0caa91043d..8313173e1fec9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,34 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
// -----
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+ // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+ // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c-128_i32 = arith.constant -128 : i32
+ // CHECK: %c127_i32 = arith.constant 127 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false}...
[truncated]
|
Apologies for not getting around to reviewing yet @ShivaChen, I just wanted to note that this change will need rebasing on top of #155695 |
5bbbc3d
to
0baeef5
Compare
Hi @lhutton1, I have rebased on top of #155695. Thanks for letting me know about the code base changes. :-) |
a1777d8
to
41fc88d
Compare
gentle ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @ShivaChen, the changes look good to me overall, however, I'm not very familiar with the tosa-to-linalg lowering - perhaps @FranklandJack @GeorgeARM would be interested in taking a look?
Are you planning to contribute more ext-dyanamic support for this lowering? I'm wondering if it would be helpful to create an issue to track the progress?
Hi @lhutton1, I have created issue as #157378. Hi @FranklandJack @GeorgeARM, |
Thanks for creating the issue. I was actually thinking of more of a checklist for ext-dynamic support in tosa-to-linalg. e.g. non-const mul shift was supported in #155197, this PR supports non-const shift, multiplier, input_zp, and output_zp. In the future we could look at supporting the rest of the ext-dynamic changes (https://www.mlplatform.org/tosa/tosa_spec_1_0_0.html#_ext_dynamic_extension). An issue will help to track this progress |
Hi @lhutton1, it seems to be a good idea. I have updated #157378. Feel free to tweak the description of the issue. Thank you very much for your guidance. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the delay in reviewing this @ShivaChen
// 1. Pushing it into 'genericInputs'. | ||
// 2. Appending a corresponding affine map to 'indexingMaps'. | ||
// - If the shift is constant, set 'shiftConstant' instead. | ||
static void setupLinalgGenericOpInputAndIndexingMapForShift( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This appears to be mostly duplicated code to the function above with the exception of the datatype of the values passed in and the op operand (multiplier or shift). Is it possible to unify these to help reduce the duplication?
builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result); | ||
} | ||
} else if (zpTy.getIntOrFloatBitWidth() > 32) { | ||
result = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to add a test for this case?
if (zpTy.getIntOrFloatBitWidth() < 32) { | ||
if (zpTy.isUnsignedInteger()) { | ||
result = | ||
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this case was covered by the tests, is it possible to add one?
result = builder.create<arith::ConstantOp>( | ||
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp)); | ||
} | ||
return result; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: based on the logic here, I think some of the nesting could be removed by returning result directly?
} else { | ||
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth(); | ||
// Extend zeropoint for sub-32bits widths. | ||
const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just noticed that this doesn't align with the pseudo-code implementation in the spec, which extends the inputs and input zp to int48
(https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rescale). This doesn't seem to have been introduced in this PR, but it's something we might want to review in the future. (cc @GeorgeARM @RoboTux)
The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled. When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.
OutputZp width > 32 is not allowed by the spec
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Many thanks for your helpful comments to reshape the code.
09912e4
to
8acc72c
Compare
70bf8d1
to
d9bcd1e
Compare
gentle ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for only just getting back to this @ShivaChen, LGTM!
Thanks for your kind and thorough review. :-) |
The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled.
When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.
The commit help to pass following Tosa conformance tests.
rescale_22x20_i32_outi8_sc0_rmS_pc0_iu0_ou0_dyn
rescale_31x18_i8_outi8_sc0_rmS_pc0_iu1_ou0_dyn
rescale_20x19_i16_outi8_sc0_rmS_pc0_iu1_ou0_dyn