

[inductor] proxy_tensor.py throws SyntaxError when using .random_ #151432


Open
shaoyuyoung opened this issue Apr 16, 2025 · 2 comments
Labels
dynamo-triage-jan2025 · high priority · module: pt2-dispatcher · oncall: pt2 · triaged

Comments

@shaoyuyoung
Contributor

shaoyuyoung commented Apr 16, 2025

πŸ› Describe the bug

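Symptom: compiling a model that calls the in-place `Tensor.random_` with `torch.compile(backend="inductor")` fails with a `SyntaxError` raised from `proxy_tensor.py`; the same model runs fine in eager mode.
Device backend: both CPP and Triton
Repro: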

```python
import torch
from torch._inductor import config

# Fall back to eager kernels for random ops inside Inductor
# (slower, but useful for comparing against eager results).
config.fallback_random = True
torch.set_grad_enabled(False)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # random_(0, 2) overwrites x in place with 0/1 values and returns x,
        # so both operands of `*` refer to the same mutated tensor.
        x = x * x.random_(0, 2)
        return x


model = Model()
x = torch.randn(4, 8)
inputs = [x]


def run_test(model, inputs, backend):
    torch.manual_seed(0)
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    try:
        output = model(*inputs)
        print(f"succeed on {backend}")
    except Exception as e:
        print(e)


# Note: the eager run mutates inputs[0] in place before the inductor run.
run_test(model, inputs, 'eager')
run_test(model, inputs, 'inductor')
```
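The repro's `forward` depends on in-place aliasing: Python evaluates the left operand `x` first, then `x.random_(0, 2)` overwrites that same tensor with values drawn uniformly from {0, 1} and returns it, so the product multiplies the mutated tensor by itself. A minimal eager-only sketch of the reference behaviour a compiled model would be expected to match (the function name `forward_eager` is hypothetical, and nothing here exercises Inductor or reproduces the bug):

```python
import torch


def forward_eager(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical standalone copy of Model.forward from the repro.
    # x.random_(0, 2) fills x in place with integers from [0, 2) and returns x,
    # so both operands of `*` are the same, already-mutated tensor.
    return x * x.random_(0, 2)


torch.manual_seed(0)
x = torch.randn(4, 8)
out = forward_eager(x)

# Every element is 0.0 or 1.0; since 0*0 == 0 and 1*1 == 1, the result
# equals the mutated input, although it lives in new storage.
print(out.unique())
print(torch.equal(out, x))
```

This mix of input mutation and aliasing in a single expression is what the compiled path has to preserve.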

Error logs

eager

```
succeed on eager
```

inductor

```
SyntaxError: invalid syntax (proxy_tensor.py:1265 in wrapped, line 5)
```
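The repro's `run_test` prints only the exception message via `print(e)`, so the report shows no stack trace. A small variant, hypothetically named `run_test_verbose`, prints the full traceback with the standard-library `traceback.print_exc` and returns whether the run succeeded; it is a debugging sketch, not part of the original report:

```python
import traceback

import torch


def run_test_verbose(model, inputs, backend):
    """Like run_test in the repro, but prints the full traceback on failure.

    Hypothetical helper: returns True on success and False on failure.
    """
    torch.manual_seed(0)
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    try:
        model(*inputs)
        print(f"succeed on {backend}")
        return True
    except Exception:
        # Print the whole stack to stderr instead of just the message.
        traceback.print_exc()
        return False
```

With the repro's `Model` and config, calling `run_test_verbose(Model(), [torch.randn(4, 8)], "inductor")` would show the stack leading to the `SyntaxError` instead of the one-line message.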

Versions

nightly 20250414

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @bdhirsh @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @muchulee8 @amjames @aakhundov

@shaoyuyoung
Contributor Author

It seems that proxy_tensor.py is changing quickly, so I am not sure whether this problem still appears on the master branch.

@mlazos
Contributor

mlazos commented Apr 21, 2025

Still occurs on efdcc98.

mlazos added the triaged and module: pt2-dispatcher labels and removed the module: inductor and triage review labels on Apr 22, 2025