Closed
Labels
actionable, high priority, module: correctness (silent), module: inductor, oncall: pt2, pt2: ubn, triaged
Description
Minimal repro:
import torch

def f(x):
    return torch.full_like(x, 3)

x = torch.randn(4, 5, 6).transpose(1, -1)
out = f(x)
out_compiled = torch.compile(f, backend="aot_eager_decomp_partition")(x)
print(out.stride())
print(out_compiled.stride())
# prints
# (30, 1, 6)   <- eager: strides of x are preserved
# (30, 5, 1)   <- compiled: contiguous strides
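For reference, here is a sketch of the eager semantics the compiled path should match. `torch.full_like` defaults to `memory_format=torch.preserve_format`, and for a non-overlapping, dense input such as this transposed tensor, that means the output copies the input's strides exactly. The helper `full_like_preserving_strides` is hypothetical (not a PyTorch API). It assumes a dense, non-overlapping input and reproduces the eager strides by allocating with `torch.empty_strided` and then filling. The `torch.full(x.shape, ...)` line only illustrates the contiguous layout that matches the compiled output above. It is not meant to show what the actual decomposition does.

```python
import torch

def full_like_preserving_strides(x: torch.Tensor, fill_value) -> torch.Tensor:
    # Hypothetical helper (not a PyTorch API): allocate an output with exactly
    # x's sizes and strides, then fill it. This mirrors eager full_like under
    # memory_format=torch.preserve_format, assuming x is non-overlapping and dense.
    out = torch.empty_strided(x.size(), x.stride(), dtype=x.dtype, device=x.device)
    return out.fill_(fill_value)

x = torch.randn(4, 5, 6).transpose(1, -1)  # shape (4, 6, 5), strides (30, 1, 6)

eager = torch.full_like(x, 3)             # default memory_format=torch.preserve_format
sketch = full_like_preserving_strides(x, 3)
contiguous = torch.full(x.shape, 3.0)     # contiguous layout, same as the compiled output above

print(x.stride())           # (30, 1, 6)
print(eager.stride())       # (30, 1, 6)
print(sketch.stride())      # (30, 1, 6)
print(contiguous.stride())  # (30, 5, 1)
print(torch.equal(eager, sketch))  # True: same shape and values
```

All three tensors compare equal by value. Only their strides differ, which is why the mismatch is silent until a downstream op (as in the NJT case) depends on the layout.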
This appears to be the root cause of an NJT (nested jagged tensor) compile crash that @jbschlosser ran into (see his repro, njt_patch, and error).
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov