mul: remove opmath cast sequence #9663
Conversation
Fixes issue #9662

@ysiraichi Hi, could you take a look when you have a chance?

Could you add a test to verify that this does what we expect?

I could add a test for the `mul` op, but as you know, this patch affects the intermediate type casting, so it is tricky to capture the cast → op → cast-back sequence. Do you have any ideas on how to validate this?

You can use
Remove the explicit opmath-driven cast chain (bf16→f32→bf16, etc.) from `mul`. The op now executes in the dtype chosen by standard dtype promotion, without inserting unconditional upcast/downcast steps. The opmath machinery itself is left in place for future use.
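The cast chain being removed can be sketched in pure Python, emulating bf16 by truncating the low bits of a float32 encoding (a simplification: the `bf16` helper here is hypothetical and rounds toward zero, while real hardware rounds to nearest-even). The sketch also shows why removing the chain looked safe for a lone multiply: both paths round the exact product exactly once, so they agree.

```python
import struct

def bf16(x: float) -> float:
    # Emulate bfloat16 by zeroing the low 16 bits of the float32
    # encoding (truncation; real hardware rounds to nearest-even).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def mul_opmath(a: float, b: float) -> float:
    # Pattern removed by this PR: upcast bf16 -> f32, multiply in
    # f32 (plain Python float here), downcast the result to bf16.
    hi_a, hi_b = float(a), float(b)   # bf16 -> f32 upcast (a no-op here)
    return bf16(hi_a * hi_b)          # f32 multiply, then downcast

def mul_promoted(a: float, b: float) -> float:
    # After the PR: multiply in the promoted dtype (bf16), rounding
    # the exact product once -- the same single rounding as above.
    return bf16(a * b)

a, b = bf16(0.1), bf16(0.3)
print(mul_opmath(a, b) == mul_promoted(a, b))  # True: a lone mul is unaffected
```

Differences between the two paths only surface when intermediates chain, e.g. when many products are accumulated, which is the scenario behind the later revert.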
Force-pushed from f22abe8 to 55c07ed
@ysiraichi, I was going to try the
ysiraichi left a comment:
Thank you for the PR.
@sshonTT @mmanzoorTT looks like we may need to revert this change due to #9699.
This reverts commit 2a9138a.
Commit 2a9138a removed `.use_opmathtype_for_compute()` from the element-wise `mul` operation. This breaks the mixed-precision accumulation behavior expected by the Neuron compiler, which traces/compiles on CPU and later executes the binary on Neuron hardware, causing accuracy degradation in transformer models that use mixed-precision compilation. Reverts: commit 2a9138a; other changes are the result of a rebase from r2.9. Fixes: Model accuracy failures with mixed-precision accumulation #9699
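The accumulation concern behind the revert can be illustrated without any Neuron/XLA dependency. In this pure-Python sketch (the `bf16` helper is a hypothetical emulation that truncates float32 bits; real hardware rounds to nearest-even), a bf16 accumulator silently drops small products once the running sum grows, while an opmath-style higher-precision accumulator retains them:

```python
import struct

def bf16(x: float) -> float:
    # Emulate bfloat16 by zeroing the low 16 bits of the float32
    # encoding (truncation; real hardware rounds to nearest-even).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def dot_bf16_accum(xs, ys):
    # Accumulate in bf16: once |acc| is large enough, each small
    # product falls below one bf16 ulp and is silently dropped.
    acc = 0.0
    for a, b in zip(xs, ys):
        acc = bf16(acc + bf16(a * b))
    return acc

def dot_f32_accum(xs, ys):
    # Opmath-style: accumulate in higher precision (Python float
    # stands in for f32), downcast only the final result.
    acc = 0.0
    for a, b in zip(xs, ys):
        acc += bf16(a * b)
    return bf16(acc)

xs = ys = [bf16(0.1)] * 1000
print(dot_bf16_accum(xs, ys))  # stalls early (about 2.0 with this data)
print(dot_f32_accum(xs, ys))   # close to the true sum (about 9.8)
```

This is the kind of error the Neuron compiler's expected mixed-precision accumulation avoids, and why dropping the upcast degraded transformer-model accuracy.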