Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][tosa] Allow unsigned types for rescale ops during validation #138253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

lhutton1
Copy link
Contributor

@lhutton1 lhutton1 commented May 2, 2025

This commit allows unsigned types (ui8/ui16/ui32) when checking for valid element types, only for rescale operators.

This commit allows unsigned types (ui8/ui16/ui32) when checking
for valid element types, only for rescale operators.

Signed-off-by: Luke Hutton <[email protected]>
Change-Id: I0525c5a5542e20e832d1bf150635be7423d3799a
@lhutton1
Copy link
Contributor Author

lhutton1 commented May 2, 2025

Fixes #135699

@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

This commit allows unsigned types (ui8/ui16/ui32) when checking for valid element types, only for rescale operators.


Full diff: https://github.com/llvm/llvm-project/pull/138253.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+17-6)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+24)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index e8b52d48347ab..feedc5057bea0 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -562,7 +562,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   bool CheckVariable(Operation *op);
   bool CheckVariableReadOrWrite(Operation *op);
-  bool isValidElementType(Type type);
+  bool isValidElementType(Type type, const bool allowUnsigned = false);
 
   SmallVector<
       std::function<LogicalResult(Operation *, const tosa::TargetEnv &)>>
@@ -1176,7 +1176,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   return success();
 }
 
-bool TosaValidation::isValidElementType(Type type) {
+bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
   if (isa<FloatType>(type)) {
     return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType,
                Float8E5M2Type>(type);
@@ -1191,6 +1191,13 @@ bool TosaValidation::isValidElementType(Type type) {
       case 48:
         return true;
       }
+    } else if (allowUnsigned && intTy.isUnsigned()) {
+      switch (intTy.getWidth()) {
+      case 8:
+      case 16:
+      case 32:
+        return true;
+      }
     }
   } else if (mlir::isa<tosa::shapeType>(type)) {
     return true;
@@ -1209,11 +1216,15 @@ void TosaValidation::runOnOperation() {
     if (op->getDialect() != tosaDialect)
       return;
 
-    // perform valid element type check at the beginning to
-    // protect rest of code against quantized element types
+    // validate operator element types:
+    // - rescale operator is allowed to have ui8/ui16/ui32
+    //   operands/results
+    // - perform valid element type check at the beginning to
+    //   protect rest of code against quantized element types
+    const bool opIsRescale = isa<tosa::RescaleOp>(op);
     for (Value operand : op->getOperands()) {
       auto elementTy = getElementTypeOrSelf(operand);
-      if (!isValidElementType(elementTy)) {
+      if (!isValidElementType(elementTy, opIsRescale)) {
         op->emitOpError() << "is not profile-aligned: element type "
                           << elementTy << " is not legal";
         return signalPassFailure();
@@ -1221,7 +1232,7 @@ void TosaValidation::runOnOperation() {
     }
     for (Type resultTy : op->getResultTypes()) {
       auto elementTy = getElementTypeOrSelf(resultTy);
-      if (!isValidElementType(elementTy)) {
+      if (!isValidElementType(elementTy, opIsRescale)) {
         op->emitOpError() << "is not profile-aligned: element type "
                           << elementTy << " is not legal";
         return signalPassFailure();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c4f95b47628d1..c1f4f22887c79 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1937,3 +1937,27 @@ func.func @test_clamp_min_larger_than_max_fp32(%arg0: tensor<13x21x3xf32>) -> te
   %0 = tosa.clamp %arg0 {min_val = 2.0 : f32, max_val = -1.1: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_rescale_input_unsigned
+func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui8>) -> (tensor<1x1xi8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xi8>
+  return %r : tensor<1x1xi8>
+}
+
+// -----
+
+// CHECK-LABEL: test_rescale_output_unsigned
+func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui8>) {
+  %0 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %1 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %2 = "tosa.const"() <{values = dense<3> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
+  return %r : tensor<1x1xui8>
+}

@tatwaichong
Copy link
Contributor

LGTM. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants