[FlexAttention] export fails to trace with functorch #153063


Open
tugsbayasgalan opened this issue May 7, 2025 · 0 comments
Labels
module: flex attention · module: functorch Pertaining to torch.func or pytorch/functorch · oncall: pt2 · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@tugsbayasgalan
Contributor

tugsbayasgalan commented May 7, 2025

๐Ÿ› Describe the bug

import torch
import torch.nn as nn
from torch.func import vmap
from torch.export import export

# 1. Inner model (shared across batch)
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))

# 2. Module that applies vmap over inner model
class BatchedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TinyModel()

    def forward(self, x):
        return vmap(self.model)(x)  # vectorize over batch

# 3. Instantiate and test export
x = torch.randn(16, 8)
model = BatchedModel().eval()

# 4. Export
graph_module = export(model, (x,))
print(graph_module.module())
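
For contrast, exporting the inner model directly (no vmap) traces fine, so the failure looks specific to the functorch interaction. A minimal sketch, assuming the same setup as above:

# Sketch: same environment as the repro; exporting TinyModel alone
# succeeds, so only the vmap-wrapped path hits the assertion below.
inner = TinyModel().eval()
ep_inner = export(inner, (torch.randn(16, 8),))
print(ep_inner.module())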

The export of BatchedModel errors with:
in _free_unbacked_symbols_with_path(a, path, real, shape_env, pending, simplify)
   1024 elif isinstance(a, torch.Tensor):
   1025     from torch._subclasses.fake_tensor import FakeTensor
-> 1027     assert isinstance(a, FakeTensor)
   1028     r.update(
   1029         go(
   1030             a.size(),
   (...)
   1033         )
   1034     )
   1035     if a.layout not in [
   1036         torch.sparse_csr,
   1037         torch.sparse_csc,
   1038         torch.sparse_bsr,
   1039         torch.sparse_bsc,
   1040     ]:
AssertionError: 

It seems to me that, at the pre-dispatch level, we are not properly peeking into the FakeTensor inside a BatchedTensor.
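
If that is right, the unbacked-symbol extraction would presumably need to peel functorch wrappers before the FakeTensor assert fires. A rough sketch of the unwrapping idea, using the internal torch._C._functorch helpers (where exactly this would hook in is an assumption on my part):

from torch._C._functorch import is_batchedtensor, get_unwrapped

def unwrap_to_fake(t):
    # Peel off BatchedTensor layers; during export tracing the
    # innermost tensor should be the FakeTensor the assert expects.
    while is_batchedtensor(t):
        t = get_unwrapped(t)
    return t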

Versions

main

cc @chauhang @penguinwu @zou3519 @Chillee @samdow @kshitij12345 @ydwu4 @drisspg @yanboliang @BoyuanFeng

@tugsbayasgalan tugsbayasgalan changed the title from "export fails to trace with functorch" to "[FlexAttention] export fails to trace with functorch" May 7, 2025
@janeyx99 janeyx99 added module: functorch Pertaining to torch.func or pytorch/functorch module: flex attention labels May 7, 2025
@pytorch-bot pytorch-bot bot added module: higher order operators torch.cond and similar, oncall: pt2, and module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op) labels May 7, 2025
@janeyx99 janeyx99 added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed oncall: pt2, module: higher order operators torch.cond and similar, and module: pt2-dispatcher PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op) labels May 7, 2025