[PT2] torch.layer_norm errors in eager but runs fine in backend=aot_eager_decomp_partition #151478
Labels
enhancement, module: decompositions, module: error checking, oncall: pt2, triaged
🚀 The feature, motivation and pitch
torch.layer_norm throws an error in eager mode when the input and weight have different dtypes. However, it runs fine with backend=aot_eager_decomp_partition, because the decomposition of torch.layer_norm into fp32 ops skips the dtype check (see the repro sketch below).
We ran into this because our online jobs disable PT2 while offline training requires it; ideally we want the same behavior across eager and compile.
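A minimal repro sketch of the divergence. The shapes, dtypes, and device below are illustrative assumptions, not taken from the original report; whether eager rejects the mixed dtypes can depend on device and build.

```python
import torch
import torch.nn.functional as F

def fn(x, weight, bias):
    # layer_norm over the last dimension
    return F.layer_norm(x, (8,), weight, bias)

device = "cuda"  # assumption: eager rejects mixed dtypes on this device/build
x = torch.randn(4, 8, dtype=torch.float16, device=device)
weight = torch.randn(8, dtype=torch.float32, device=device)
bias = torch.randn(8, dtype=torch.float32, device=device)

# Eager: raises RuntimeError because input (fp16) and weight/bias (fp32) differ.
try:
    fn(x, weight, bias)
except RuntimeError as e:
    print("eager raised:", e)

# Compiled with the decomposing backend: runs, because the decomposition
# upcasts to fp32 internally and never performs the eager dtype check.
compiled = torch.compile(fn, backend="aot_eager_decomp_partition")
out = compiled(x, weight, bias)
print("compiled ran fine, output dtype:", out.dtype)
```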
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @malfet @chauhang @penguinwu @SherlockNoMad @bdhirsh
Error
[screenshot of the eager error message]
Alternatives
No response
Additional context
No response