[dynamic shapes] data-dependent error when backed + unbacked expression resolves statically #151491


Open
pianpwk opened this issue Apr 17, 2025 · 1 comment
Assignees
Labels
module: dynamic shapes oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@pianpwk
Contributor

pianpwk commented Apr 17, 2025

πŸ› Describe the bug

Reported by @ColinPeppler

Getting this log, which suggests we can simplify the expression to False using the backed hint, but it still raises a data-dependent error:

torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression False (unhinted: Ne(Mod(18*u0, ((s58*u0)//8)), 0)).  (Size-like symbols: none)

Caused by: (_refs/__init__.py:3806 in _reshape_view_helper)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=""
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/data/users/pianpwk/pytorch/custom_tests/test_s0_u0.py", line 11, in forward
    return y.view(-1, 144)

To fix the error, insert one of the following checks before this call:
  1. torch._check(False)
  2. torch._check(True)

Repro:

import torch

from torch.export import export, Dim

class Foo(torch.nn.Module):
    def forward(self, a, b):
        u0 = a.item()
        y = torch.zeros(u0, 18, b.shape[0])
        torch._check((u0 * 18 * b.shape[0]) // 144 != u0)
        torch._check(u0 % ((u0 * 18 * b.shape[0]) // 144) != 0) 
        return y.view(-1, 144)

ep = export(
    Foo(),
    (torch.tensor([6]), torch.randn(8)),
    dynamic_shapes={
        "a": None,
        "b": (Dim.DYNAMIC,),
    },
)
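For reference (plain-Python sanity check, not part of the repro): with the example input b of shape (8,), y has u0 * 18 * 8 == 144 * u0 elements, so view(-1, 144) is always shape-valid for the concrete hint, and the first torch._check above asserts something that is in fact false under the hints:

```python
# Plain-Python sanity check over sample values of u0 (hypothetical range,
# chosen just for illustration): y.numel() is 144 * u0 with the example b.
for u0 in range(1, 50):
    numel = u0 * 18 * 8
    assert numel % 144 == 0    # view(-1, 144) is shape-valid
    assert numel // 144 == u0  # so torch._check(numel // 144 != u0) is false under the hints
print("view(-1, 144) is valid for all sampled u0")
```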

Versions

latest nightly

cc @chauhang @penguinwu @ezyang @bobrenjc93

@pianpwk pianpwk self-assigned this Apr 17, 2025
@mlazos mlazos added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 21, 2025
@ysiraichi
Collaborator

I agree that your example, specifically, should not hit that error. That's because the first torch._check((u0 * 18 * b.shape[0]) // 144 != u0) is actually False when u0 is replaced with its hint. That said, I was able to reach the same error without either torch._check() call.

The error is actually coming from the view refs implementation. While I'm not sure that guard was actually necessary in this view call, the unbacked expression does not resolve statically.

The reason you are seeing the simplified expression False in the error message (and in the fix suggestion) is that we use the hinted unbacked expression: when we replace s58 with its hint, the expression resolves statically.

In summary, I believe the incorrect part here is the error message. We should detect when this happens and print only the unhinted expression.
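The hint substitution can be reproduced outside of PyTorch (a plain-Python sketch; the hint values u0 = 6 and s58 = 8 come from the example inputs, while s58 = 7 is a hypothetical value chosen to show the unresolved case). With s58 = 8, (s58*u0)//8 collapses to u0, and 18*u0 is always divisible by u0, so the guard evaluates to False:

```python
def guard(u0: int, s58: int) -> bool:
    # The unhinted guard PyTorch could not resolve:
    # Ne(Mod(18*u0, (s58*u0)//8), 0)
    return (18 * u0) % ((s58 * u0) // 8) != 0

# With the backed hint s58 = 8: (8*u0)//8 == u0 and 18*u0 % u0 == 0,
# so the guard resolves statically to False for every positive u0.
for u0 in range(1, 100):
    assert guard(u0, s58=8) is False

# With a different (hypothetical) s58 the expression does not resolve
# to a constant; e.g. s58 = 7 gives 108 % 5 == 3:
print(guard(u0=6, s58=7))  # prints True
```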
