Conversation


@ShivaChen ShivaChen commented Aug 29, 2025

The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.

This commit helps pass the following TOSA conformance tests:
rescale_22x20_i32_outi8_sc0_rmS_pc0_iu0_ou0_dyn
rescale_31x18_i8_outi8_sc0_rmS_pc0_iu1_ou0_dyn
rescale_20x19_i16_outi8_sc0_rmS_pc0_iu1_ou0_dyn
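The scalar body that the lowering emits per element (visible in the CHECK lines of the added test: widen the zero points with `arith.extsi`, subtract the input zero point, apply the fixed-point scale, add the output zero point, then clamp to the i8 range) can be modeled as a standalone sketch. This is an illustrative model of the arithmetic, not MLIR code; `rescaleScalar` is a hypothetical name, and the rounding in `applyScale` assumes a simple round-half-up behavior rather than the exact `tosa.apply_scale` semantics.

```cpp
#include <algorithm>
#include <cstdint>

// Hedged model of tosa.apply_scale: fixed-point multiply by `multiplier`,
// then arithmetic right shift by `shift` with round-half-up. The real op's
// rounding modes (SINGLE_ROUND / DOUBLE_ROUND) are more involved.
static int64_t applyScale(int64_t value, int32_t multiplier, int8_t shift) {
  int64_t round = shift > 0 ? (int64_t{1} << (shift - 1)) : 0;
  return (value * multiplier + round) >> shift;
}

// Model of the generic-op body for an i8 -> i8 rescale:
//   extsi zps to i32, subi, apply_scale, addi, clamp to [-128, 127].
int8_t rescaleScalar(int8_t value, int32_t multiplier, int8_t shift,
                     int8_t inputZp, int8_t outputZp) {
  // Widen and subtract the input zero point (arith.extsi + arith.subi).
  int64_t v = static_cast<int32_t>(value) - static_cast<int32_t>(inputZp);
  // Scale, then add the output zero point (arith.addi).
  int64_t scaled = applyScale(v, multiplier, shift) +
                   static_cast<int32_t>(outputZp);
  // Clamp to the signed i8 range (arith.maxsi / arith.minsi).
  scaled = std::clamp<int64_t>(scaled, -128, 127);
  return static_cast<int8_t>(scaled);
}
```

With a multiplier of `1 << 30` and a shift of 30 the scale factor is 1, so the value passes through shifted only by the zero points; saturation kicks in once the scaled result leaves the i8 range.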


llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (ShivaChen)

Changes

The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.


Patch is 20.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155967.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+249-96)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+28)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 73046e0da361a..cc1289f397dff 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1342,6 +1342,186 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
   }
 };
 
+// Collapse tensor<1xiN> into tensor<iN>
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  SmallVector<ReassociationExprs, 1> reassociation;
+  // Create the collapsed type
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto elemType = inputType.getElementType();
+  auto collapsedType = RankedTensorType::get({}, elemType);
+  // Emit the collapse op
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+                                                  reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
+// by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the multiplier is constant, set 'multiplierConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> multiplierExprs{
+      rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If we are rescaling per-channel then we need to store the multiplier
+    // values in a buffer.
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+    } else {
+      auto multiplierType =
+          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                                rewriter.getI32Type());
+      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    if (op.getMultiplier().getType().getRank() == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It would affect as broadcast for scalar values in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getMultiplier());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  }
+}
+
+// The shift may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift is constant, set 'shiftConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If we are rescaling per-channel then we need to store the shift
+    // values in a buffer.
+    if (shiftValues.size() == 1) {
+      shiftConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+    } else {
+      auto shiftType =
+          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                                rewriter.getIntegerType(8));
+      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
+          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    if (op.getShift().getType().getRank() == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It would affect as broadcast for scalar values in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getShift(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getShift());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  }
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[3];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    }
+  } else {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Extend zeropoint for sub-32bits widths.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
+// Return the i32 outputZp to be used in subsequent arithmetic operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[4];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+      result =
+          builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+    }
+  } else {
+    const int32_t attrBitwidth = 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1375,41 +1555,45 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 
     // The shift and multiplier values.
     DenseElementsAttr shiftElems;
-    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant shift input values");
+    bool isShiftConstant = false;
+    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      isShiftConstant = true;
 
     DenseElementsAttr multiplierElems;
-    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant multiplier input values");
-
-    llvm::SmallVector<int8_t> shiftValues =
-        llvm::to_vector(shiftElems.getValues<int8_t>());
-    // explicit cast is required here
-    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
-        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
-                        [](IntegerAttr attr) -> int32_t {
-                          return static_cast<int32_t>(attr.getInt());
-                        }));
-
-    // If we shift by more than the bitwidth, this just sets to 0.
-    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
-      if (shiftValues[i] > 63) {
-        shiftValues[i] = 0;
-        multiplierValues[i] = 0;
+    bool isMultiplierConstant = false;
+    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      isMultiplierConstant = true;
+
+    llvm::SmallVector<int8_t> shiftValues;
+    llvm::SmallVector<int32_t> multiplierValues;
+    StringAttr roundingMode;
+    bool doubleRound;
+
+    if (isMultiplierConstant && isShiftConstant) {
+      shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+      // explicit cast is required here
+      multiplierValues = llvm::to_vector(
+          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                          [](IntegerAttr attr) -> int32_t {
+                            return static_cast<int32_t>(attr.getInt());
+                          }));
+
+      // If we shift by more than the bitwidth, this just sets to 0.
+      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+        if (shiftValues[i] > 63) {
+          shiftValues[i] = 0;
+          multiplierValues[i] = 0;
+        }
       }
-    }
-
-    // Double round only occurs if shift is greater than 31, check that this
-    // is ever true.
+      // Double round only occurs if shift is greater than 31, check that this
+      // is ever true.
+      doubleRound = op.getRoundingMode() == "DOUBLE_ROUND" &&
+                    llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    } else
+      doubleRound = op.getRoundingMode() == "DOUBLE_ROUND";
 
-    bool doubleRound =
-        op.getRoundingMode() == "DOUBLE_ROUND" &&
-        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
-    StringAttr roundingMode = doubleRound
-                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
-                                  : rewriter.getStringAttr("SINGLE_ROUND");
+    roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND")
+                               : rewriter.getStringAttr("SINGLE_ROUND");
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1418,46 +1602,35 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // If we are rescaling per-channel then we need to store the multiplier
     // values in a buffer.
     Value multiplierConstant;
-    int64_t multiplierArg = 0;
-    if (multiplierValues.size() == 1) {
-      multiplierConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> multiplierExprs{
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto multiplierType =
-          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
-                                rewriter.getI32Type());
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc,
-          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, multiplierExprs,
-                                            rewriter.getContext()));
-
-      multiplierArg = indexingMaps.size() - 1;
-    }
-
+    setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+        rewriter, multiplierValues, genericInputs, indexingMaps,
+        isMultiplierConstant, op, multiplierConstant);
     // If we are rescaling per-channel then we need to store the shift
     // values in a buffer.
     Value shiftConstant;
-    int64_t shiftArg = 0;
-    if (shiftValues.size() == 1) {
-      shiftConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> shiftExprs = {
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto shiftType =
-          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
-                                rewriter.getIntegerType(8));
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, shiftExprs,
-                                            rewriter.getContext()));
-      shiftArg = indexingMaps.size() - 1;
+    setupLinalgGenericOpInputAndIndexingMapForShift(
+        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+        shiftConstant);
+
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // It would affect as broadcast for scalar values in linalg::GenericOp.
+    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+    // The inputZp and outputZp may be either constant or non-constant,
+    // depending on whether dynamic extension is enabled.
+    // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+    //     1. Pushing it into 'genericInputs'.
+    //     2. Appending a corresponding affine map to 'indexingMaps'.
+    if (failed(maybeIZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+      indexingMaps.push_back(broadcastMap);
+    }
+    if (failed(maybeOZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+      indexingMaps.push_back(broadcastMap);
     }
 
     // Indexing maps for output values.
@@ -1477,40 +1650,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Type valueTy = value.getType();
 
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
-          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
-          // Extend zeropoint for sub-32bits widths.
-          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
-          auto inputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
-                               *maybeIZp));
-
+          auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+                                          nestedLoc, blockArgs);
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
+          auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+                                         nestedLoc, blockArgs);
 
           IntegerType outIntType =
               cast<IntegerType>(blockArgs.back().getType());
           unsigned outBitWidth = outIntType.getWidth();
-          const int32_t outAttrBitwidth = 32;
           assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
-          auto outputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
-                               *maybeOZp));
-
-          Value multiplier = multiplierConstant ? multiplierConstant
-                                                : blockArgs[multiplierArg];
-          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
+
+          Value multiplier =
+              multiplierConstant ? multiplierConstant : blockArgs[1];
+          Value shift = shiftConstant ? shiftConstant : blockArgs[2];
 
           if (valueTy.isUnsignedInteger()) {
             value = UnrealizedConversionCastOp::create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aee0caa91043d..8313173e1fec9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,34 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+  // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false}...
[truncated]


llvmbot commented Aug 29, 2025

@llvm/pr-subscribers-mlir-tosa

-    StringAttr roundingMode = doubleRound
-                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
-                                  : rewriter.getStringAttr("SINGLE_ROUND");
+    roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND")
+                               : rewriter.getStringAttr("SINGLE_ROUND");
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1418,46 +1602,35 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // If we are rescaling per-channel then we need to store the multiplier
     // values in a buffer.
     Value multiplierConstant;
-    int64_t multiplierArg = 0;
-    if (multiplierValues.size() == 1) {
-      multiplierConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> multiplierExprs{
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto multiplierType =
-          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
-                                rewriter.getI32Type());
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc,
-          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, multiplierExprs,
-                                            rewriter.getContext()));
-
-      multiplierArg = indexingMaps.size() - 1;
-    }
-
+    setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+        rewriter, multiplierValues, genericInputs, indexingMaps,
+        isMultiplierConstant, op, multiplierConstant);
     // If we are rescaling per-channel then we need to store the shift
     // values in a buffer.
     Value shiftConstant;
-    int64_t shiftArg = 0;
-    if (shiftValues.size() == 1) {
-      shiftConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> shiftExprs = {
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto shiftType =
-          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
-                                rewriter.getIntegerType(8));
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, shiftExprs,
-                                            rewriter.getContext()));
-      shiftArg = indexingMaps.size() - 1;
+    setupLinalgGenericOpInputAndIndexingMapForShift(
+        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+        shiftConstant);
+
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // It would affect as broadcast for scalar values in linalg::GenericOp.
+    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+    // The inputZp and outputZp may be either constant or non-constant,
+    // depending on whether dynamic extension is enabled.
+    // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+    //     1. Pushing it into 'genericInputs'.
+    //     2. Appending a corresponding affine map to 'indexingMaps'.
+    if (failed(maybeIZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+      indexingMaps.push_back(broadcastMap);
+    }
+    if (failed(maybeOZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+      indexingMaps.push_back(broadcastMap);
     }
 
     // Indexing maps for output values.
@@ -1477,40 +1650,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Type valueTy = value.getType();
 
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
-          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
-          // Extend zeropoint for sub-32bits widths.
-          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
-          auto inputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
-                               *maybeIZp));
-
+          auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+                                          nestedLoc, blockArgs);
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
+          auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+                                         nestedLoc, blockArgs);
 
           IntegerType outIntType =
               cast<IntegerType>(blockArgs.back().getType());
           unsigned outBitWidth = outIntType.getWidth();
-          const int32_t outAttrBitwidth = 32;
           assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
-          auto outputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
-                               *maybeOZp));
-
-          Value multiplier = multiplierConstant ? multiplierConstant
-                                                : blockArgs[multiplierArg];
-          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
+
+          Value multiplier =
+              multiplierConstant ? multiplierConstant : blockArgs[1];
+          Value shift = shiftConstant ? shiftConstant : blockArgs[2];
 
           if (valueTy.isUnsignedInteger()) {
             value = UnrealizedConversionCastOp::create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aee0caa91043d..8313173e1fec9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,34 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+  // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false}...
[truncated]

@lhutton1
Contributor

Apologies for not getting around to reviewing yet @ShivaChen, I just wanted to note that this change will need rebasing on top of #155695

@ShivaChen
Collaborator Author

Apologies for not getting around to reviewing yet @ShivaChen, I just wanted to note that this change will need rebasing on top of #155695

Hi @lhutton1, I have rebased on top of #155695. Thanks for letting me know about the codebase changes. :-)

@ShivaChen
Collaborator Author

gentle ping

Contributor

@lhutton1 lhutton1 left a comment

Thanks @ShivaChen, the changes look good to me overall, however, I'm not very familiar with the tosa-to-linalg lowering - perhaps @FranklandJack @GeorgeARM would be interested in taking a look?

Are you planning to contribute more ext-dynamic support for this lowering? I'm wondering if it would be helpful to create an issue to track the progress.

@ShivaChen
Collaborator Author

Thanks @ShivaChen, the changes look good to me overall, however, I'm not very familiar with the tosa-to-linalg lowering - perhaps @FranklandJack @GeorgeARM would be interested in taking a look?

Are you planning to contribute more ext-dynamic support for this lowering? I'm wondering if it would be helpful to create an issue to track the progress.

Hi @lhutton1, I have created an issue: #157378.
Thanks for your guidance and for the suggestion to create an issue to track the progress.

Hi @FranklandJack @GeorgeARM,
If it’s not too much trouble, would you be interested in reviewing? Your feedback would be greatly appreciated.

@lhutton1
Contributor

lhutton1 commented Sep 8, 2025

Thanks for creating the issue. I was actually thinking more of a checklist for ext-dynamic support in tosa-to-linalg: e.g. non-const mul/shift was supported in #155197, and this PR supports non-const shift, multiplier, input_zp, and output_zp. In the future we could look at supporting the rest of the ext-dynamic changes (https://www.mlplatform.org/tosa/tosa_spec_1_0_0.html#_ext_dynamic_extension). An issue will help to track this progress.

@ShivaChen
Collaborator Author

Thanks for creating the issue. I was actually thinking more of a checklist for ext-dynamic support in tosa-to-linalg: e.g. non-const mul/shift was supported in #155197, and this PR supports non-const shift, multiplier, input_zp, and output_zp. In the future we could look at supporting the rest of the ext-dynamic changes (https://www.mlplatform.org/tosa/tosa_spec_1_0_0.html#_ext_dynamic_extension). An issue will help to track this progress.

Hi @lhutton1, that seems like a good idea. I have updated #157378. Feel free to tweak the issue description. Thank you very much for your guidance.

Contributor

@lhutton1 lhutton1 left a comment

Apologies for the delay in reviewing this @ShivaChen

// 1. Pushing it into 'genericInputs'.
// 2. Appending a corresponding affine map to 'indexingMaps'.
// - If the shift is constant, set 'shiftConstant' instead.
static void setupLinalgGenericOpInputAndIndexingMapForShift(
Contributor

This appears to be mostly duplicated code to the function above with the exception of the datatype of the values passed in and the op operand (multiplier or shift). Is it possible to unify these to help reduce the duplication?

builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
}
} else if (zpTy.getIntOrFloatBitWidth() > 32) {
result =
Contributor

Is it possible to add a test for this case?

if (zpTy.getIntOrFloatBitWidth() < 32) {
if (zpTy.isUnsignedInteger()) {
result =
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
Contributor

I don't think this case was covered by the tests, is it possible to add one?

result = builder.create<arith::ConstantOp>(
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
}
return result;
Contributor

nit: based on the logic here, I think some of the nesting could be removed by returning result directly.

} else {
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
// Extend zeropoint for sub-32bits widths.
const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
Contributor

I just noticed that this doesn't align with the pseudo-code implementation in the spec, which extends the inputs and input zp to int48 (https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_rescale). This doesn't seem to have been introduced in this PR, but it's something we might want to review in the future. (cc @GeorgeARM @RoboTux)

The shift, multiplier, inputZp, and outputZp can be either constant or
non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to
linalg::GenericOp, and corresponding affine maps are appended to the
indexingMaps.
OutputZP width > 32 is not allowed by the spec
Collaborator Author

@ShivaChen ShivaChen left a comment

Many thanks for your helpful comments to reshape the code.

@ShivaChen
Collaborator Author

gentle ping

Contributor

@lhutton1 lhutton1 left a comment

Apologies for only just getting back to this @ShivaChen, LGTM!

@ShivaChen ShivaChen merged commit 107ca63 into llvm:main Oct 14, 2025
9 checks passed
@ShivaChen ShivaChen deleted the tosa-rescale-dyn branch October 14, 2025 02:10
@ShivaChen
Collaborator Author

Apologies for only just getting back to this @ShivaChen, LGTM!

Thanks for your kind and thorough review. :-)
