Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MLIR][TOSA-Linalg] Fix rescale lowering for unsigned input zp #138780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

RoboTux
Copy link
Contributor

@RoboTux RoboTux commented May 6, 2025

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.
@llvmbot
Copy link
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: Thomas Preud'homme (RoboTux)

Changes

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.


Full diff: https://github.com/llvm/llvm-project/pull/138780.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+14-31)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+24-22)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+35-3)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..0b69cd2814fb9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,15 +82,6 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                           rhsOrResult);
 }
 
-template <typename T>
-static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
-                       OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(zp);
-  return rewriter.create<arith::ConstantOp>(
-      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
-}
-
 static Value createLinalgBodyCalculationForElementwiseOp(
     Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
     ConversionPatternRewriter &rewriter) {
@@ -1467,11 +1458,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Value value = blockArgs[0];
           Type valueTy = value.getType();
 
-          // For now we do all of our math in 64-bit. This is not optimal but
-          // should be correct for now, consider computing correct bit depth
-          // later.
-          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
           if (failed(maybeIZp)) {
             (void)rewriter.notifyMatchFailure(
@@ -1479,9 +1465,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           }
 
-          auto inputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
-              nestedBuilder);
+          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+          // Extend zeropoint for sub-32bits widths.
+          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
+          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
+              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
+                                    *maybeIZp));
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
           if (failed(maybeOZp)) {
@@ -1490,16 +1479,14 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           };
 
-          // pre-process OutputZP as it can be unsigned
-          auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
-          APInt OZp(outBitwidth, !op.getOutputUnsigned());
-          OZp = static_cast<int64_t>(*maybeOZp);
-          *maybeOZp = op.getOutputUnsigned()
-                          ? static_cast<int64_t>(OZp.getZExtValue())
-                          : OZp.getSExtValue();
-
-          auto outputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+          IntegerType outIntType =
+              cast<IntegerType>(blockArgs.back().getType());
+          unsigned outBitWidth = outIntType.getWidth();
+          const int32_t outAttrBitwidth = 32;
+          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
+          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
+              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
+                                    *maybeOZp));
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
@@ -1527,10 +1514,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
 
           // Saturate to the output size.
-          IntegerType outIntType =
-              cast<IntegerType>(blockArgs.back().getType());
-          unsigned outBitWidth = outIntType.getWidth();
-
           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index de06b621cbe3d..371c6dc27b428 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2118,7 +2118,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
 // return failure if val is not a constant
 // set zp to -1 if val is non-zero float or val is not integer nor float
 // otherwise set zp to val's constant value
-static FailureOr<int64_t> getZeroPoint(Value val) {
+static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
   ElementsAttr zpAttr;
   if (!matchPattern(val, m_Constant(&zpAttr))) {
     return failure();
@@ -2135,7 +2135,10 @@ static FailureOr<int64_t> getZeroPoint(Value val) {
   }
 
   if (llvm::isa<IntegerType>(zpElemType)) {
-    return zpAttr.getValues<APInt>()[0].getSExtValue();
+    if (signExtend)
+      return zpAttr.getValues<APInt>()[0].getSExtValue();
+    else
+      return zpAttr.getValues<APInt>()[0].getZExtValue();
   }
 
   // return non-zero value to trigger error check
@@ -2175,8 +2178,7 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
       return op.emitOpError()
              << "expect " << tensorName << "_zp of 0, got " << zp;
     }
-    if (zpElemType.isInteger(16) && tensorUnsigned &&
-        zp != static_cast<int16_t>(32768)) {
+    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
       return op.emitOpError() << "expect " << tensorName
                               << "_zp of 0 or 32768 for unsigned int16 "
                               << tensorName << ", got " << zp;
@@ -2186,30 +2188,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
   return success();
 }
 
-#define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
+#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
   FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
-    return getZeroPoint(get##OPERAND_NAME##Zp());                              \
+    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
   }                                                                            \
   LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
     return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
   }
 
-ZERO_POINT_HELPER(Conv2DOp, Input)
-ZERO_POINT_HELPER(Conv2DOp, Weight)
-ZERO_POINT_HELPER(Conv3DOp, Input)
-ZERO_POINT_HELPER(Conv3DOp, Weight)
-ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
-ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
-ZERO_POINT_HELPER(TransposeConv2DOp, Input)
-ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
-ZERO_POINT_HELPER(AvgPool2dOp, Input)
-ZERO_POINT_HELPER(AvgPool2dOp, Output)
-ZERO_POINT_HELPER(MatMulOp, A)
-ZERO_POINT_HELPER(MatMulOp, B)
-ZERO_POINT_HELPER(NegateOp, Input1)
-ZERO_POINT_HELPER(NegateOp, Output)
-ZERO_POINT_HELPER(RescaleOp, Input)
-ZERO_POINT_HELPER(RescaleOp, Output)
+ZERO_POINT_HELPER(Conv2DOp, Input, true)
+ZERO_POINT_HELPER(Conv2DOp, Weight, true)
+ZERO_POINT_HELPER(Conv3DOp, Input, true)
+ZERO_POINT_HELPER(Conv3DOp, Weight, true)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
+ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
+ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
+ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
+ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
+ZERO_POINT_HELPER(MatMulOp, A, true)
+ZERO_POINT_HELPER(MatMulOp, B, true)
+ZERO_POINT_HELPER(NegateOp, Input1, true)
+ZERO_POINT_HELPER(NegateOp, Output, true)
+ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
+ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
+  // CHECK: [[C128:%.+]] = arith.constant 128
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C0:%.+]] = arith.constant 0
+  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9ccb310c4491d..56d76585be71b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1517,7 +1517,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
-  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got 65535}}
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }

@llvmbot
Copy link
Member

llvmbot commented May 6, 2025

@llvm/pr-subscribers-mlir

Author: Thomas Preud'homme (RoboTux)

Changes

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.


Full diff: https://github.com/llvm/llvm-project/pull/138780.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+14-31)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+24-22)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+35-3)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+1-1)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..0b69cd2814fb9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,15 +82,6 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                           rhsOrResult);
 }
 
-template <typename T>
-static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
-                       OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(zp);
-  return rewriter.create<arith::ConstantOp>(
-      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
-}
-
 static Value createLinalgBodyCalculationForElementwiseOp(
     Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
     ConversionPatternRewriter &rewriter) {
@@ -1467,11 +1458,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Value value = blockArgs[0];
           Type valueTy = value.getType();
 
-          // For now we do all of our math in 64-bit. This is not optimal but
-          // should be correct for now, consider computing correct bit depth
-          // later.
-          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
           if (failed(maybeIZp)) {
             (void)rewriter.notifyMatchFailure(
@@ -1479,9 +1465,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           }
 
-          auto inputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
-              nestedBuilder);
+          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+          // Extend zeropoint for sub-32bits widths.
+          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
+          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
+              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
+                                    *maybeIZp));
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
           if (failed(maybeOZp)) {
@@ -1490,16 +1479,14 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           };
 
-          // pre-process OutputZP as it can be unsigned
-          auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
-          APInt OZp(outBitwidth, !op.getOutputUnsigned());
-          OZp = static_cast<int64_t>(*maybeOZp);
-          *maybeOZp = op.getOutputUnsigned()
-                          ? static_cast<int64_t>(OZp.getZExtValue())
-                          : OZp.getSExtValue();
-
-          auto outputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+          IntegerType outIntType =
+              cast<IntegerType>(blockArgs.back().getType());
+          unsigned outBitWidth = outIntType.getWidth();
+          const int32_t outAttrBitwidth = 32;
+          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
+          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
+              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
+                                    *maybeOZp));
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
@@ -1527,10 +1514,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
 
           // Saturate to the output size.
-          IntegerType outIntType =
-              cast<IntegerType>(blockArgs.back().getType());
-          unsigned outBitWidth = outIntType.getWidth();
-
           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index de06b621cbe3d..371c6dc27b428 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2118,7 +2118,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
 // return failure if val is not a constant
 // set zp to -1 if val is non-zero float or val is not integer nor float
 // otherwise set zp to val's constant value
-static FailureOr<int64_t> getZeroPoint(Value val) {
+static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
   ElementsAttr zpAttr;
   if (!matchPattern(val, m_Constant(&zpAttr))) {
     return failure();
@@ -2135,7 +2135,10 @@ static FailureOr<int64_t> getZeroPoint(Value val) {
   }
 
   if (llvm::isa<IntegerType>(zpElemType)) {
-    return zpAttr.getValues<APInt>()[0].getSExtValue();
+    if (signExtend)
+      return zpAttr.getValues<APInt>()[0].getSExtValue();
+    else
+      return zpAttr.getValues<APInt>()[0].getZExtValue();
   }
 
   // return non-zero value to trigger error check
@@ -2175,8 +2178,7 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
       return op.emitOpError()
              << "expect " << tensorName << "_zp of 0, got " << zp;
     }
-    if (zpElemType.isInteger(16) && tensorUnsigned &&
-        zp != static_cast<int16_t>(32768)) {
+    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
       return op.emitOpError() << "expect " << tensorName
                               << "_zp of 0 or 32768 for unsigned int16 "
                               << tensorName << ", got " << zp;
@@ -2186,30 +2188,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
   return success();
 }
 
-#define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
+#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
   FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
-    return getZeroPoint(get##OPERAND_NAME##Zp());                              \
+    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
   }                                                                            \
   LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
     return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
   }
 
-ZERO_POINT_HELPER(Conv2DOp, Input)
-ZERO_POINT_HELPER(Conv2DOp, Weight)
-ZERO_POINT_HELPER(Conv3DOp, Input)
-ZERO_POINT_HELPER(Conv3DOp, Weight)
-ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
-ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
-ZERO_POINT_HELPER(TransposeConv2DOp, Input)
-ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
-ZERO_POINT_HELPER(AvgPool2dOp, Input)
-ZERO_POINT_HELPER(AvgPool2dOp, Output)
-ZERO_POINT_HELPER(MatMulOp, A)
-ZERO_POINT_HELPER(MatMulOp, B)
-ZERO_POINT_HELPER(NegateOp, Input1)
-ZERO_POINT_HELPER(NegateOp, Output)
-ZERO_POINT_HELPER(RescaleOp, Input)
-ZERO_POINT_HELPER(RescaleOp, Output)
+ZERO_POINT_HELPER(Conv2DOp, Input, true)
+ZERO_POINT_HELPER(Conv2DOp, Weight, true)
+ZERO_POINT_HELPER(Conv3DOp, Input, true)
+ZERO_POINT_HELPER(Conv3DOp, Weight, true)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
+ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
+ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
+ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
+ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
+ZERO_POINT_HELPER(MatMulOp, A, true)
+ZERO_POINT_HELPER(MatMulOp, B, true)
+ZERO_POINT_HELPER(NegateOp, Input1, true)
+ZERO_POINT_HELPER(NegateOp, Output, true)
+ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
+ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
+  // CHECK: [[C128:%.+]] = arith.constant 128
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C0:%.+]] = arith.constant 0
+  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9ccb310c4491d..56d76585be71b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1517,7 +1517,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
-  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got 65535}}
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }

@lhutton1
Copy link
Contributor

lhutton1 commented May 7, 2025

Thanks for the quick fix, is it possible to add a test for the failure described in: #138313 (comment)?

@RoboTux
Copy link
Contributor Author

RoboTux commented May 7, 2025

Thanks for the quick fix, is it possible to add a test for the failure described in: #138313 (comment)?

Done. I've checked that the test fails without the fix and obviously passes with the fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants