Fix mixed precision operands in splitReduction pass #22138
hanhanW merged 2 commits into iree-org:main from
Conversation
@bangtianliu please take a look. @FlintWangacc thanks for the fix, couple of updates
Sure, will do!
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp (Outdated)
I don't see any fp16/fp32 precision mixing in the MLIR you posted above. Could you please correct/clarify this?
This issue was found when compiling the deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B model.
This is before splitReduction of argmax:

```mlir
%1204:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%extracted_slice : tensor<151936xf32>) outs(%1201, %1203 : tensor<f32>, tensor<i64>) {
^bb0(%in: f32, %out: f32, %out_574: i64):
  %1292 = arith.truncf %in : f32 to f16
  %1293 = arith.extf %1292 : f16 to f32
  %1294 = linalg.index 0 : index
  %1295 = arith.index_cast %1294 : index to i64
  %1296 = arith.maximumf %1293, %out : f32
  %1297 = arith.cmpf ogt, %1293, %out : f32
  %1298 = arith.select %1297, %1295, %out_574 : i64
  linalg.yield %1296, %1298 : f32, i64
} -> (tensor<f32>, tensor<i64>)
```

This is after:

```mlir
%2731:2 = "linalg.generic"(%2724, %2727, %2730) <{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>], operandSegmentSizes = array<i32: 1, 2>}> ({
^bb0(%arg7: f16, %arg8: f32, %arg9: i64):
  %2856 = "linalg.index"() <{dim = 1 : i64}> : () -> index
  %2857 = "arith.index_cast"(%2856) : (index) -> i64
  %2858 = "arith.maximumf"(%arg7, %arg8) <{fastmath = #arith.fastmath<none>}> : (f16, f32) -> f16
  %2859 = "arith.cmpf"(%arg7, %arg8) <{fastmath = #arith.fastmath<none>, predicate = 2 : i64}> : (f16, f32) -> i1
  %2860 = "arith.select"(%2859, %2857, %arg9) : (i1, i64, i64) -> i64
  "linalg.yield"(%2858, %2860) : (f16, i64) -> ()
}) : (tensor<1187x128xf16>, tensor<1187xf32>, tensor<1187xi64>) -> (tensor<1187xf32>, tensor<1187xi64>)
```

This line shows the mixture:

```mlir
%2859 = "arith.cmpf"(%arg7, %arg8) <{fastmath = #arith.fastmath<none>, predicate = 2 : i64}> : (f16, f32) -> i1
```
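For context, a minimal sketch of the casting that resolves this mixture inside the split body; this is an illustration only, assuming the builder variables `b`, `loc`, `in`, and `outVal` are the ones visible in the review hunks quoted below, not the exact final patch:

```cpp
// Sketch: make the reduced input match the accumulator's element type before
// building arith.maximumf / arith.cmpf, so the verifier no longer sees an
// f16/f32 operand mix.
Value inCast = in;
if (in.getType() != outVal.getType()) {
  if (outVal.getType().getIntOrFloatBitWidth() >
      in.getType().getIntOrFloatBitWidth()) {
    // e.g. f16 input with an f32 accumulator: extend the input.
    inCast = b.create<arith::ExtFOp>(loc, outVal.getType(), in);
  } else {
    // e.g. f64 input with an f32 accumulator: truncate the input.
    inCast = b.create<arith::TruncFOp>(loc, outVal.getType(), in);
  }
}
Value maxVal = b.create<arith::MaximumFOp>(loc, inCast, outVal);
Value cmp =
    b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, inCast, outVal);
```

With the input cast to the accumulator's element type, both `arith.maximumf` and `arith.cmpf` see matching operand types, and the verifier error above goes away.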
Force-pushed 4836725 to 546ac6c
Force-pushed 18c9010 to 6d6a561
As Mahesh commented above, please add a test.
hanhanW left a comment:
Please improve the PR title and body like https://google.github.io/eng-practices/review/developer/cl-descriptions.html
```cpp
if (outIdx.getType() != reductionIdx.getType())
  reductionIdx = arith::IndexCastOp::create(b, loc, outIdx.getType(),
                                            reductionIdx);
Value maxVal = arith::MaximumFOp::create(b, loc, in, outVal);
Value cmp = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, in,
                                  outVal);
```
```cpp
reductionIdx =
    b.create<arith::IndexCastOp>(loc, outIdx.getType(), reductionIdx);
```
Please delete the blank line and wrap the statement with braces.
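For illustration, the braced form being requested could look like this (a sketch based on the quoted lines above, not necessarily the exact final diff):

```cpp
if (outIdx.getType() != reductionIdx.getType()) {
  reductionIdx =
      b.create<arith::IndexCastOp>(loc, outIdx.getType(), reductionIdx);
}
```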
```cpp
reductionIdx =
    b.create<arith::IndexCastOp>(loc, outIdx.getType(), reductionIdx);

// Cast the f16 input to f32 to match the output value type.
```
The comment mismatches the code. There is no f16/f32 context here at all. Please update the comment.
btw, I think you need TruncF for the inverse cases.
I use TruncFOp for the inverse case.
Fix the following error iree-org#22138

```
error: 'arith.maximumf' op requires same type for all operands and results
```

Signed-off-by: Zhengbo Wang <[email protected]>
Force-pushed fa5c28a to 0a98684
I've added a test case argmax_no_mixture in
```mlir
util.func public @argmax_no_mixture(%arg3: tensor<1x151936xf16>) -> tensor<1x1xi64> {
```
This needs a separator.
Suggested change:

```mlir
// -----

util.func public @argmax_no_mixture(%arg3: tensor<1x151936xf16>) -> tensor<1x1xi64> {
```
```mlir
  util.return %expanded_1079 : tensor<1x1xi64>
}
```
```mlir
// CHECK-LABEL: util.func public @argmax_no_mixture
```
Please add some additional checks to verify that the appropriate extf or truncf ops are being created.
```cpp
if (in.getType() != outVal.getType()) {
  if (outVal.getType().getIntOrFloatBitWidth() >
      in.getType().getIntOrFloatBitWidth()) {
    inCast = b.create<arith::ExtFOp>(loc, outVal.getType(), in);
  } else {
    inCast = b.create<arith::TruncFOp>(loc, outVal.getType(), in);
  }
}
```
nit: This could just be:
Suggested change (replacing the block above):

```cpp
if (outVal.getType().getIntOrFloatBitWidth() >
    in.getType().getIntOrFloatBitWidth()) {
  inCast = b.create<arith::ExtFOp>(loc, outVal.getType(), in);
} else if (outVal.getType().getIntOrFloatBitWidth() <
           in.getType().getIntOrFloatBitWidth()) {
  inCast = b.create<arith::TruncFOp>(loc, outVal.getType(), in);
}
```
```cpp
if (outVal.getType().getIntOrFloatBitWidth() >
    in.getType().getIntOrFloatBitWidth()) {
  inCast = b.create<arith::ExtFOp>(loc, outVal.getType(), in);
} else {
  inCast = b.create<arith::TruncFOp>(loc, outVal.getType(), in);
}
```
There are two cases, but we only have one test. Can you add a test for the other one?
Thanks for your comment. I have made some changes accordingly.
Force-pushed 38d0650 to 2c6e04c
Max191 left a comment:
The PR description shows a linalg.generic with an arith.truncf followed by an arith.extf, but there is no test for that case. Can you add one that matches the failure you saw?
Force-pushed 8eb92f7 to d1df97d
…org#22138)

The splitReduction pass was failing when argmax operations had mixed precision operands (e.g., f16 input with f32 accumulator). The pass would create `arith.maximumf` and `arith.cmpf` operations with mismatched operand types, causing verification errors.

This change ensures proper type casting between input values and output accumulator values:

- Add explicit type casting using `arith.extf` when the accumulator has higher precision than the input
- Add explicit type casting using `arith.truncf` when the accumulator has lower precision than the input

The fix handles all mixed precision scenarios while preserving the original behavior for matching types. Added test cases demonstrate the fix for both f16->f32 (extension) and f64->f32 (truncation) scenarios.

Signed-off-by: Zhengbo Wang <[email protected]>
Force-pushed d1df97d to c713085
Force-pushed c713085 to 144c2cb
Force-pushed 144c2cb to 06acc19
@FlintWangacc can you fix the pre-commit issue: https://github.com/iree-org/iree/actions/runs/19166008772/job/54820321554?pr=22138 ? Otherwise, I can't land the PR for you.
Force-pushed 726e749 to 08eb7d7
Thanks for the reminder. I have fixed the clang-format issue.
compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/SplitReduction.cpp
```mlir
// -----

util.func public @argmax_extf(%arg0: tensor<1x151936xf16>) -> tensor<1x1xi64> {
  %collapsed_967 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x151936xf16> into tensor<151936xf16>
```
You can simplify the test by using the function arguments directly, like:

```mlir
util.func public @argmax_extf(
    %input: tensor<151936xf16>,
    %init_val: tensor<f32>,
    %init_idx: tensor<i64>) -> (tensor<f32>, tensor<i64>) {
  %result:2 = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>,
                       affine_map<(d0) -> ()>,
                       affine_map<(d0) -> ()>],
      iterator_types = ["reduction"]}
      ins(%input : tensor<151936xf16>)
      outs(%init_val, %init_idx : tensor<f32>, tensor<i64>) {
  ^bb0(%in: f16, %out: f32, %out_idx: i64):
    %ext = arith.extf %in : f16 to f32
    %i = linalg.index 0 : index
    %i_cast = arith.index_cast %i : index to i64
    %max = arith.maximumf %ext, %out : f32
    %cmp = arith.cmpf ogt, %ext, %out : f32
    %sel = arith.select %cmp, %i_cast, %out_idx : i64
    linalg.yield %max, %sel : f32, i64
  } -> (tensor<f32>, tensor<i64>)
  util.return %result#0, %result#1 : tensor<f32>, tensor<i64>
}
```
I did some tests, and this MLIR does not cause the program to crash.
I just used this example MLIR to show how the tests could be further simplified. Maybe the initial values are still needed there, but do we really need tensor.collapse_shape?
I will try to simplify the MLIR program.
I have simplified the MLIR.
```mlir
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<i64>
%1 = linalg.fill ins(%c0_i64 : i64) outs(%0 : tensor<i64>) -> tensor<i64>
%2 = tensor.empty() : tensor<f32>
%3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<f32>) -> tensor<f32>
%1236:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%collapsed_967 : tensor<151936xf64>) outs(%3, %1 : tensor<f32>, tensor<i64>) {
```
Force-pushed be41e94 to b8595e2
Force-pushed a5be822 to 6cd487c
Force-pushed d368820 to 27b25fa
Force-pushed fb1e1a2 to 1db1561
[Deepseek_R1_Distill_Qwen_1.5B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) has issues for fp16. I use iree-compile to compile fp16 version of Deepseek_R1_Distill_Qwen_1.5B. It report the following error. ```bash DeepSeek_R1_Distill_Qwen_1.5B_fp16.mlir:27492:13: error: 'arith.maximumf' op requires the same type for all operands and results %5286 = torch.aten.argmax %5285, %int1_8048, %false_8049 : !torch.vtensor<[1,151936],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64> ^ DeepSeek_R1_Distill_Qwen_1.5B_fp16.mlir:13391:15: note: called from %169:57 = call @forward(%112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127, %128, %129, %130, %131, %132, %133, %134, %135, %136, %137, %138, %139, %140, %141, %142, %143, %144, %145, %146, %147, %148, %149, %150, %151, %152, %153, %154, %155, %156, %157, %158, %159, %160, %161, %162, %163, %164, %165, %166, %167, %168) : (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>, !torch.vtensor<[1,?,2,128],f16>) -> (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, 
!torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>, !torch.vtensor<[1,1,2,128],f16>) ^ DeepSeek_R1_Distill_Qwen_1.5B_fp16.mlir:27492:13: note: see current operation: %2858 = "arith.maximumf"(%arg7, %arg8) <{fastmath = #arith.fastmath<none>}> : (f16, f32) -> f16 %5286 = torch.aten.argmax %5285, %int1_8048, %false_8049 : !torch.vtensor<[1,151936],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64> ``` before splitReduction ```bash %1209:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%extracted_slice : tensor<151936xf32>) outs(%1201, %1203 : tensor<f32>, tensor<i64>) { ^bb0(%in: f32, %out: f32, %out_577: i64): %1297 = arith.truncf %in : f32 to f16 %1298 = arith.extf %1297 : f16 to f32 %1299 = linalg.index 0 : index %1300 = arith.index_cast %1299 : index to i64 %1301 = arith.maximumf %1298, %out : f32 %1302 = arith.cmpf ogt, %1298, %out : f32 %1303 = arith.select %1302, %1300, %out_577 : i64 linalg.yield %1301, %1303 : f32, i64 } -> (tensor<f32>, tensor<i64>) ``` after splitReduction ```bash %1209:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1208#0, %1208#1 : tensor<1187xf32>, tensor<1187xi64>) outs(%1201, %1203 : tensor<f32>, tensor<i64>) { ^bb0(%in: f32, %in_577: i64, %out: f32, %out_578: i64): %1297 = linalg.index 0 : index %1298 = arith.muli %1297, %c128 : index %1299 = arith.index_cast %1298 : index to i64 %1300 = arith.addi %1299, %in_577 : i64 %1301 = arith.maximumf %in, %out : f32 %1302 = arith.cmpf ogt, %in, %out : f32 %1303 = arith.select %1302, %1300, %out_578 : i64 linalg.yield %1301, %1303 : f32, i64 } -> (tensor<f32>, tensor<i64>) ``` This will cause iree-compile crash in the following check of mlir file because the mix of fp16 and fp32 precision. Signed-off-by: Zhengbo Wang <[email protected]>
Deepseek_R1_Distill_Qwen_1.5B has issues with fp16. I used iree-compile to compile the fp16 version of Deepseek_R1_Distill_Qwen_1.5B, and it reports the 'arith.maximumf' verifier error shown above.

Before splitReduction:

```mlir
%1209:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%extracted_slice : tensor<151936xf32>) outs(%1201, %1203 : tensor<f32>, tensor<i64>) {
^bb0(%in: f32, %out: f32, %out_577: i64):
  %1297 = arith.truncf %in : f32 to f16
  %1298 = arith.extf %1297 : f16 to f32
  %1299 = linalg.index 0 : index
  %1300 = arith.index_cast %1299 : index to i64
  %1301 = arith.maximumf %1298, %out : f32
  %1302 = arith.cmpf ogt, %1298, %out : f32
  %1303 = arith.select %1302, %1300, %out_577 : i64
  linalg.yield %1301, %1303 : f32, i64
} -> (tensor<f32>, tensor<i64>)
```

After splitReduction:

```mlir
%1209:2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1208#0, %1208#1 : tensor<1187xf32>, tensor<1187xi64>) outs(%1201, %1203 : tensor<f32>, tensor<i64>) {
^bb0(%in: f32, %in_577: i64, %out: f32, %out_578: i64):
  %1297 = linalg.index 0 : index
  %1298 = arith.muli %1297, %c128 : index
  %1299 = arith.index_cast %1298 : index to i64
  %1300 = arith.addi %1299, %in_577 : i64
  %1301 = arith.maximumf %in, %out : f32
  %1302 = arith.cmpf ogt, %in, %out : f32
  %1303 = arith.select %1302, %1300, %out_578 : i64
  linalg.yield %1301, %1303 : f32, i64
} -> (tensor<f32>, tensor<i64>)
```

This will cause iree-compile to crash in a later check of the MLIR file because of the mix of fp16 and fp32 precision.