Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MLIR][TOSA-Linalg] Fix rescale lowering for unsigned input zp #138313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 6, 2025

Conversation

RoboTux
Copy link
Contributor

@RoboTux RoboTux commented May 2, 2025

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.

@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Thomas Preud'homme (RoboTux)

Changes

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.


Full diff: https://github.com/llvm/llvm-project/pull/138313.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+18-26)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+35-3)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                           rhsOrResult);
 }
 
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
 static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
-                       OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+                        bool isSigned, Location loc, OpBuilder &rewriter) {
+
+  // Zero the signed-extended bits if isSigned is false.
+  zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
   return rewriter.create<arith::ConstantOp>(
-      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+      loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
 }
 
 static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Value value = blockArgs[0];
           Type valueTy = value.getType();
 
-          // For now we do all of our math in 64-bit. This is not optimal but
-          // should be correct for now, consider computing correct bit depth
-          // later.
-          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
           if (failed(maybeIZp)) {
             (void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           }
 
-          auto inputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+          const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+          auto inputZp = createConstOpFromSExtZp(
+              *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
               nestedBuilder);
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           };
 
-          // pre-process OutputZP as it can be unsigned
-          auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
-          APInt OZp(outBitwidth, !op.getOutputUnsigned());
-          OZp = static_cast<int64_t>(*maybeOZp);
-          *maybeOZp = op.getOutputUnsigned()
-                          ? static_cast<int64_t>(OZp.getZExtValue())
-                          : OZp.getSExtValue();
-
-          auto outputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+          IntegerType outIntType =
+              cast<IntegerType>(blockArgs.back().getType());
+          unsigned outBitWidth = outIntType.getWidth();
+          auto outputZp = createConstOpFromSExtZp(
+              *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+              !op.getOutputUnsigned(), loc, nestedBuilder);
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
 
           // Saturate to the output size.
-          IntegerType outIntType =
-              cast<IntegerType>(blockArgs.back().getType());
-          unsigned outBitWidth = outIntType.getWidth();
-
           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
+  // CHECK: [[C128:%.+]] = arith.constant 128
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C0:%.+]] = arith.constant 0
+  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-mlir

Author: Thomas Preud'homme (RoboTux)

Changes

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.


Full diff: https://github.com/llvm/llvm-project/pull/138313.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+18-26)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+35-3)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                           rhsOrResult);
 }
 
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
 static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
-                       OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+                        bool isSigned, Location loc, OpBuilder &rewriter) {
+
+  // Zero the signed-extended bits if isSigned is false.
+  zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
   return rewriter.create<arith::ConstantOp>(
-      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+      loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
 }
 
 static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Value value = blockArgs[0];
           Type valueTy = value.getType();
 
-          // For now we do all of our math in 64-bit. This is not optimal but
-          // should be correct for now, consider computing correct bit depth
-          // later.
-          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
           if (failed(maybeIZp)) {
             (void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           }
 
-          auto inputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+          const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+          auto inputZp = createConstOpFromSExtZp(
+              *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
               nestedBuilder);
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           };
 
-          // pre-process OutputZP as it can be unsigned
-          auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
-          APInt OZp(outBitwidth, !op.getOutputUnsigned());
-          OZp = static_cast<int64_t>(*maybeOZp);
-          *maybeOZp = op.getOutputUnsigned()
-                          ? static_cast<int64_t>(OZp.getZExtValue())
-                          : OZp.getSExtValue();
-
-          auto outputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+          IntegerType outIntType =
+              cast<IntegerType>(blockArgs.back().getType());
+          unsigned outBitWidth = outIntType.getWidth();
+          auto outputZp = createConstOpFromSExtZp(
+              *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+              !op.getOutputUnsigned(), loc, nestedBuilder);
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
 
           // Saturate to the output size.
-          IntegerType outIntType =
-              cast<IntegerType>(blockArgs.back().getType());
-          unsigned outBitWidth = outIntType.getWidth();
-
           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
+  // CHECK: [[C128:%.+]] = arith.constant 128
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C0:%.+]] = arith.constant 0
+  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Thomas Preud'homme (RoboTux)

Changes

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.


Full diff: https://github.com/llvm/llvm-project/pull/138313.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+18-26)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+35-3)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                           rhsOrResult);
 }
 
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
 static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
-                       OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+                        bool isSigned, Location loc, OpBuilder &rewriter) {
+
+  // Zero the signed-extended bits if isSigned is false.
+  zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
   return rewriter.create<arith::ConstantOp>(
-      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+      loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
 }
 
 static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Value value = blockArgs[0];
           Type valueTy = value.getType();
 
-          // For now we do all of our math in 64-bit. This is not optimal but
-          // should be correct for now, consider computing correct bit depth
-          // later.
-          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
           if (failed(maybeIZp)) {
             (void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           }
 
-          auto inputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+          const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+          auto inputZp = createConstOpFromSExtZp(
+              *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
               nestedBuilder);
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           };
 
-          // pre-process OutputZP as it can be unsigned
-          auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
-          APInt OZp(outBitwidth, !op.getOutputUnsigned());
-          OZp = static_cast<int64_t>(*maybeOZp);
-          *maybeOZp = op.getOutputUnsigned()
-                          ? static_cast<int64_t>(OZp.getZExtValue())
-                          : OZp.getSExtValue();
-
-          auto outputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+          IntegerType outIntType =
+              cast<IntegerType>(blockArgs.back().getType());
+          unsigned outBitWidth = outIntType.getWidth();
+          auto outputZp = createConstOpFromSExtZp(
+              *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+              !op.getOutputUnsigned(), loc, nestedBuilder);
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
 
           // Saturate to the output size.
-          IntegerType outIntType =
-              cast<IntegerType>(blockArgs.back().getType());
-          unsigned outBitWidth = outIntType.getWidth();
-
           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
+  // CHECK: [[C128:%.+]] = arith.constant 128
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C0:%.+]] = arith.constant 0
+  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

@RoboTux RoboTux requested review from GeorgeARM and banach-space May 2, 2025 19:24
Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.
Also clarify zeropoint extension rules.
@RoboTux RoboTux force-pushed the fix_rescale_unsigned_zp branch from 5b201ea to 1c36395 Compare May 6, 2025 13:45
Copy link

github-actions bot commented May 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. LGTM @RoboTux

@lhutton1 lhutton1 merged commit b67880d into llvm:main May 6, 2025
11 checks passed
@RoboTux RoboTux deleted the fix_rescale_unsigned_zp branch May 6, 2025 22:57
RoboTux added a commit that referenced this pull request May 6, 2025
@RoboTux
Copy link
Contributor Author

RoboTux commented May 6, 2025

I had to revert because when the zero point was 32768 in i16 case the verifier returned an error. I've created #138780 for the new revision.

GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…138313)

Lowering of tosa.rescale to Linalg unconditionally sign-extend the input
zero-point value, even when unsigned_input is true. This commit refactor
zeropoint handling to share the same logic between input and output
zeropoint.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants