[PT2] torch.layer_norm errors in eager but runs fine in backend=aot_eager_decomp_partition #151478

Open
weifengpy opened this issue Apr 16, 2025 · 2 comments
Labels
- enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
- module: decompositions: Topics related to decomposition (excluding PrimTorch)
- module: error checking: Bugs related to incorrect/lacking error checking
- oncall: pt2
- triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@weifengpy
Contributor

weifengpy commented Apr 16, 2025

🚀 The feature, motivation and pitch

torch.layer_norm throws an error when the input and weight are in different dtypes. However, it runs fine with backend=aot_eager_decomp_partition, because that backend decomposes torch.layer_norm into fp32 ops.

We ran into this because an online job disables pt2 while offline training requires pt2. Ideally we want the same behavior across eager and compile.

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @malfet @chauhang @penguinwu @SherlockNoMad @bdhirsh

# python test_layer_norm.py
import torch

def forward(input):
    normalized_shape = (4, )
    # weight and bias are created in fp32, while the input passed in below is bf16
    weight = torch.ones(4, device="cuda")
    bias = torch.ones(4, device="cuda")
    eps = 0.1
    output = torch.layer_norm(
        input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled
    )
    return output


x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [2.0, 4.0, 6.0, 8.0]], device="cuda")

# no error
forward_compiled = torch.compile(forward, backend="aot_eager_decomp_partition")
forward_compiled(x.to(torch.bfloat16))

# error
forward_compiled = torch.compile(forward, backend="aot_eager")
forward_compiled(x.to(torch.bfloat16))

# error (plain eager fails in the same way)
# forward(x.to(torch.bfloat16))

Error:

RuntimeError: expected scalar type BFloat16 but found Float

While executing %native_layer_norm : [num_users=1] = call_function[target=torch.ops.aten.native_layer_norm.default](args = (%arg0_1, [4], %ones, %ones_1, 0.1), kwargs = {})
GraphModule: class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "bf16[2, 4][4, 1]"):
         # File: /data/users/weif/pytorch/test_layer_norm.py:5 in forward, code: weight = torch.ones(4, device="cuda")
        ones: "f32[4][1]" = torch.ops.aten.ones.default([4], device = device(type='cuda'), pin_memory = False)

         # File: /data/users/weif/pytorch/test_layer_norm.py:6 in forward, code: bias = torch.ones(4, device="cuda")
        ones_1: "f32[4][1]" = torch.ops.aten.ones.default([4], device = device(type='cuda'), pin_memory = False)

         # File: /data/users/weif/pytorch/test_layer_norm.py:8 in forward, code: output = torch.layer_norm(
        native_layer_norm = torch.ops.aten.native_layer_norm.default(arg0_1, [4], ones, ones_1, 0.1);  arg0_1 = ones = ones_1 = None
        getitem: "bf16[2, 4][4, 1]" = native_layer_norm[0];  native_layer_norm = None
        return (getitem,)
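For context on why the decomposed backend runs: in the graph above, weight and bias are created as fp32 ones while the input is bf16. Below is a minimal sketch of the decomposed computation, assuming an fp32 upcast followed by a cast back to the input dtype (my own approximation for illustration, not the actual code in torch._decomp):

import torch

def layer_norm_decomp_like(input, normalized_shape, weight, bias, eps):
    # Upcast everything to fp32 for the computation, mirroring what the
    # decomposition effectively does for low-precision inputs.
    x = input.float()
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    out = (x - mean) / torch.sqrt(var + eps)
    out = out * weight.float() + bias.float()
    # Cast the result back to the original input dtype (bf16 here), which is
    # why the compiled output above is bf16 even though weight/bias are fp32.
    return out.to(input.dtype)

Because every op in this sketch sees matching dtypes, no mixed-dtype kernel is ever invoked, so the compiled path returns a bf16 result instead of erroring.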

Alternatives

No response

Additional context

No response

@bdhirsh bdhirsh added module: error checking Bugs related to incorrect/lacking error checking oncall: pt2 module: decompositions Topics related to decomposition (excluding PrimTorch) labels Apr 16, 2025
@bdhirsh
Contributor

bdhirsh commented Apr 16, 2025

@weifengpy it sounds like there is a dtype assertion that runs in eager that we're missing in compile. We should probably fix that. Just to confirm - it sounds like you want compile to error here in the same way that eager errors?
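For reference, a rough Python-level illustration of the kind of dtype check eager enforces and the compiled path currently skips (the real check lives in the ATen kernel; the helper name and exact message here are hypothetical):

def _check_layer_norm_dtypes(input, weight, bias):
    # Illustrative only: raise the same kind of error eager raises when the
    # affine parameters do not match the input dtype.
    for name, t in (("weight", weight), ("bias", bias)):
        if t is not None and t.dtype != input.dtype:
            raise RuntimeError(
                f"expected scalar type {input.dtype} but found {t.dtype} ({name})"
            )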

@weifengpy
Contributor Author

> it sounds like there is a dtype assertion

Is it just a dtype assertion? I thought eager cannot run at all because of the mixed dtypes in the CUDA kernels.

> sounds like you want compile to error here in the same way that eager errors?

Is a warning an option? I am worried about breaking internal jobs by throwing hard errors, but I am even more worried about silent dtype casting.

@mlazos mlazos added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix high priority labels Apr 22, 2025
@masnesral masnesral added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label May 5, 2025