
[inductor] [aot] torch.linalg.lu can't accept slice operation, behaving differently with eager #151401


Open
shaoyuyoung opened this issue Apr 16, 2025 · 0 comments
Labels: high priority, module: dynamo, oncall: pt2, triaged


shaoyuyoung commented Apr 16, 2025

πŸ› Describe the bug

Symptom: the result of torch.linalg.lu can't be sliced under torch.compile, so the behavior differs from eager. As the repro shows, I use [:2] to take P and L from the returned (P, L, U) tuple. This works in eager mode, but aot_eager raises a TypeError about an unexpected dynamic_attributes keyword argument.
Device backend: both CPP and Triton
Repro

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)


class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.linalg.lu returns the named tuple (P, L, U); keep only P and L
        P, L = torch.linalg.lu(x)[:2]
        return P, L


model = Model()


x = torch.randn(2, 4, 3, 3)  # batch of 2 x 4 square 3x3 matrices

inputs = [x]


def run_test(model, inputs, backend):
    torch.manual_seed(0)
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    try:
        output = model(*inputs)
        print(f"succeed on {backend}")
    except Exception as e:
        print(e)


run_test(model, inputs, 'eager')
run_test(model, inputs, 'aot_eager')
```
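For reference, here is a small eager-only sanity check of what the slice in forward is expected to produce. It shows that torch.linalg.lu returns a named tuple with fields P, L and U, that [:2] yields the same tensors as accessing out.P and out.L, and that the factors reconstruct the input as P @ L @ U. This only illustrates the eager semantics that compiled backends should match; the tolerance is an assumption chosen for float32.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 4, 3, 3)

out = torch.linalg.lu(x)   # named tuple with fields P, L, U
P, L = out[:2]             # the same slice used in the repro

# In eager mode the slice is just the first two fields of the named tuple.
assert torch.equal(P, out.P) and torch.equal(L, out.L)

# The factorization satisfies x = P @ L @ U up to floating-point error.
assert torch.allclose(out.P @ out.L @ out.U, x, atol=1e-4)
print(P.shape, L.shape)
```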

Error logs

eager

```
succeed on eager
```

aot_eager

```
TypeError: VariableTracker.__init__() got an unexpected keyword argument 'dynamic_attributes'
```
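A possible workaround, sketched under the assumption that the failure is specific to slicing the returned named tuple: unpack all three factors and discard U. This has not been verified against the 20250414 nightly. UnpackModel and try_backend are hypothetical names for illustration; try_backend mirrors run_test from the repro but returns the outputs (or None on failure) so eager and aot_eager results can be compared.

```python
import torch


class UnpackModel(torch.nn.Module):
    def forward(self, x):
        # Unpack all three factors instead of slicing the returned
        # named tuple with [:2]; the U factor is discarded.
        P, L, _U = torch.linalg.lu(x)
        return P, L


def try_backend(model, x, backend):
    """Run model under `backend`; return (P, L) on success, None on failure."""
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    try:
        return model(x)
    except Exception as e:  # mirror run_test: report instead of raising
        print(f"{backend} failed: {e}")
        return None


torch.manual_seed(0)
x = torch.randn(2, 4, 3, 3)
model = UnpackModel()

eager_out = try_backend(model, x, "eager")
aot_out = try_backend(model, x, "aot_eager")
if aot_out is not None:
    # If the workaround compiles, both backends should agree on the factors.
    print("aot_eager matches eager:",
          all(torch.allclose(a, b) for a, b in zip(eager_out, aot_out)))
```

If aot_eager still fails with unpacking, that would suggest the problem is not limited to the slice path.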

Versions

nightly 20250414

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames

@shaoyuyoung shaoyuyoung changed the title [inductor] [aot] torch.linalg.lu can't accept **slice operation**, behaving differently with eager [inductor] [aot] torch.linalg.lu can't accept slice operation, behaving differently with eager Apr 16, 2025
@masnesral added the triaged label May 5, 2025