-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[MLIR][TOSA-Linalg] Fix rescale lowering for unsigned input zp #138313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Thomas Preud'homme (RoboTux) Changes: Lowering of tosa.rescale to Linalg unconditionally sign-extends the input zero-point value, even when input_unsigned is true. Full diff: https://github.com/llvm/llvm-project/pull/138313.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
rhsOrResult);
}
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
- OpBuilder &rewriter) {
- auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+ bool isSigned, Location loc, OpBuilder &rewriter) {
+
+ // Zero the signed-extended bits if isSigned is false.
+ zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
return rewriter.create<arith::ConstantOp>(
- op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+ loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
}
static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value value = blockArgs[0];
Type valueTy = value.getType();
- // For now we do all of our math in 64-bit. This is not optimal but
- // should be correct for now, consider computing correct bit depth
- // later.
- int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp)) {
(void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
}
- auto inputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+ const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+ const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+ auto inputZp = createConstOpFromSExtZp(
+ *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
nestedBuilder);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
};
- // pre-process OutputZP as it can be unsigned
- auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
- APInt OZp(outBitwidth, !op.getOutputUnsigned());
- OZp = static_cast<int64_t>(*maybeOZp);
- *maybeOZp = op.getOutputUnsigned()
- ? static_cast<int64_t>(OZp.getZExtValue())
- : OZp.getSExtValue();
-
- auto outputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+ IntegerType outIntType =
+ cast<IntegerType>(blockArgs.back().getType());
+ unsigned outBitWidth = outIntType.getWidth();
+ auto outputZp = createConstOpFromSExtZp(
+ *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+ !op.getOutputUnsigned(), loc, nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
// Saturate to the output size.
- IntegerType outIntType =
- cast<IntegerType>(blockArgs.back().getType());
- unsigned outBitWidth = outIntType.getWidth();
-
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
- // CHECK: [[C17:%.+]] = arith.constant 17
+ // CHECK: [[C128:%.+]] = arith.constant 128
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
- // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return
}
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+ // CHECK: [[C19689:%.+]] = arith.constant 19689
+ // CHECK: [[C15:%.+]] = arith.constant 15
+ // CHECK: [[INIT:%.+]] = tensor.empty()
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+ // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+ // CHECK: [[C0:%.+]] = arith.constant 0
+ // CHECK: [[C234:%.+]] = arith.constant 234
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+ // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+ // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+ // CHECK: return
+ return
+}
+
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
|
@llvm/pr-subscribers-mlir Author: Thomas Preud'homme (RoboTux) Changes: Lowering of tosa.rescale to Linalg unconditionally sign-extends the input zero-point value, even when input_unsigned is true. Full diff: https://github.com/llvm/llvm-project/pull/138313.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
rhsOrResult);
}
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
- OpBuilder &rewriter) {
- auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+ bool isSigned, Location loc, OpBuilder &rewriter) {
+
+ // Zero the signed-extended bits if isSigned is false.
+ zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
return rewriter.create<arith::ConstantOp>(
- op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+ loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
}
static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value value = blockArgs[0];
Type valueTy = value.getType();
- // For now we do all of our math in 64-bit. This is not optimal but
- // should be correct for now, consider computing correct bit depth
- // later.
- int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp)) {
(void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
}
- auto inputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+ const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+ const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+ auto inputZp = createConstOpFromSExtZp(
+ *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
nestedBuilder);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
};
- // pre-process OutputZP as it can be unsigned
- auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
- APInt OZp(outBitwidth, !op.getOutputUnsigned());
- OZp = static_cast<int64_t>(*maybeOZp);
- *maybeOZp = op.getOutputUnsigned()
- ? static_cast<int64_t>(OZp.getZExtValue())
- : OZp.getSExtValue();
-
- auto outputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+ IntegerType outIntType =
+ cast<IntegerType>(blockArgs.back().getType());
+ unsigned outBitWidth = outIntType.getWidth();
+ auto outputZp = createConstOpFromSExtZp(
+ *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+ !op.getOutputUnsigned(), loc, nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
// Saturate to the output size.
- IntegerType outIntType =
- cast<IntegerType>(blockArgs.back().getType());
- unsigned outBitWidth = outIntType.getWidth();
-
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
- // CHECK: [[C17:%.+]] = arith.constant 17
+ // CHECK: [[C128:%.+]] = arith.constant 128
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
- // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return
}
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+ // CHECK: [[C19689:%.+]] = arith.constant 19689
+ // CHECK: [[C15:%.+]] = arith.constant 15
+ // CHECK: [[INIT:%.+]] = tensor.empty()
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+ // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+ // CHECK: [[C0:%.+]] = arith.constant 0
+ // CHECK: [[C234:%.+]] = arith.constant 234
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+ // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+ // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+ // CHECK: return
+ return
+}
+
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
|
@llvm/pr-subscribers-mlir-tosa Author: Thomas Preud'homme (RoboTux) Changes: Lowering of tosa.rescale to Linalg unconditionally sign-extends the input zero-point value, even when input_unsigned is true. Full diff: https://github.com/llvm/llvm-project/pull/138313.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
rhsOrResult);
}
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
- OpBuilder &rewriter) {
- auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+ bool isSigned, Location loc, OpBuilder &rewriter) {
+
+ // Zero the signed-extended bits if isSigned is false.
+ zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
return rewriter.create<arith::ConstantOp>(
- op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+ loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
}
static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value value = blockArgs[0];
Type valueTy = value.getType();
- // For now we do all of our math in 64-bit. This is not optimal but
- // should be correct for now, consider computing correct bit depth
- // later.
- int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp)) {
(void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
}
- auto inputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+ const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+ const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+ auto inputZp = createConstOpFromSExtZp(
+ *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
nestedBuilder);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
};
- // pre-process OutputZP as it can be unsigned
- auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
- APInt OZp(outBitwidth, !op.getOutputUnsigned());
- OZp = static_cast<int64_t>(*maybeOZp);
- *maybeOZp = op.getOutputUnsigned()
- ? static_cast<int64_t>(OZp.getZExtValue())
- : OZp.getSExtValue();
-
- auto outputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+ IntegerType outIntType =
+ cast<IntegerType>(blockArgs.back().getType());
+ unsigned outBitWidth = outIntType.getWidth();
+ auto outputZp = createConstOpFromSExtZp(
+ *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+ !op.getOutputUnsigned(), loc, nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
// Saturate to the output size.
- IntegerType outIntType =
- cast<IntegerType>(blockArgs.back().getType());
- unsigned outBitWidth = outIntType.getWidth();
-
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
- // CHECK: [[C17:%.+]] = arith.constant 17
+ // CHECK: [[C128:%.+]] = arith.constant 128
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
- // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return
}
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+ // CHECK: [[C19689:%.+]] = arith.constant 19689
+ // CHECK: [[C15:%.+]] = arith.constant 15
+ // CHECK: [[INIT:%.+]] = tensor.empty()
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+ // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+ // CHECK: [[C0:%.+]] = arith.constant 0
+ // CHECK: [[C234:%.+]] = arith.constant 234
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+ // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+ // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+ // CHECK: return
+ return
+}
+
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
|
Lowering of tosa.rescale to Linalg unconditionally sign-extends the input zero-point value, even when input_unsigned is true. This commit refactors zero-point handling so that the input and output zero points share the same logic.
It also clarifies the zero-point extension rules.
5b201ea
to
1c36395
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. LGTM @RoboTux
I had to revert this because the verifier returned an error when the zero point was 32768 in the i16 case. I've created #138780 for the new revision. |
…138313) Lowering of tosa.rescale to Linalg unconditionally sign-extends the input zero-point value, even when input_unsigned is true. This commit refactors zero-point handling so that the input and output zero points share the same logic.
llvm#138313)" This reverts commit b67880d.
Lowering of tosa.rescale to Linalg unconditionally sign-extends the input
zero-point value, even when input_unsigned is true. This commit refactors
zero-point handling so that the input and output zero points share the
same logic.