Composition of torch.compile and torch.func.grad silently produces a wrong result. #136662
Comments
cc @guilhermeleobas. Tentatively marking hi-pri for silent incorrectness. |
commenting to bump last update time: this is still an active WIP |
Thanks for the report @anatolvitold. I was able to reduce your reproducer to just:

import torch

def g(x):
    y = torch.zeros_like(x)
    y[:, 0] += x[:, 2].abs()
    y[:, 0] += x[:, 2].abs()
    return y.sum()

def f(x):
    return torch.func.grad(g)(x)

@torch.compile(backend='inductor', fullgraph=True)
def h(x):
    return f(x)

x = torch.randn(2, 3)
print(f(x))
# tensor([[ 0., 0., -2.],
#         [ 0., 0., 2.]])
print(h(x))
# tensor([[ 0., 0., -1.],
#         [ 0., 0., 1.]])

One thing to notice is that the error is only reproduced when the backend is |
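For reference, the gradient of g can be worked out by hand, which makes it easy to tell which of the two printed results is the wrong one. Both in-place updates add x[:, 2].abs() to the same column, so g(x) = 2 * sum_i |x[i, 2]|, and the gradient is 2 * sign(x[i, 2]) in column 2 and zero elsewhere. The following is a minimal sketch of that check; expected_grad and matches_expected are hypothetical helper names that do not appear in the original report.

```python
import torch


def g(x):
    y = torch.zeros_like(x)
    y[:, 0] += x[:, 2].abs()
    y[:, 0] += x[:, 2].abs()
    return y.sum()


def expected_grad(x):
    # Hand-derived gradient of g: 2 * sign(x[:, 2]) in column 2, zero elsewhere.
    out = torch.zeros_like(x)
    out[:, 2] = 2 * torch.sign(x[:, 2])
    return out


def matches_expected(grad_fn, x):
    # grad_fn is any callable returning the gradient of g at x
    # (eager torch.func.grad(g), or a compiled wrapper around it).
    return torch.allclose(grad_fn(x), expected_grad(x))


torch.manual_seed(0)
x = torch.randn(2, 3)
eager_ok = matches_expected(torch.func.grad(g), x)
print("eager matches hand-derived gradient:", eager_ok)
```

Passing the compiled h from the reproducer above to matches_expected gives the same check for the compiled path. On an affected build it would be expected to return False, consistent with the values of magnitude 1 printed in this comment.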
Thank you @guilhermeleobas for your response. Given that |
@zou3519, the culprit is indeed inductor. If we compile the graph generated by autograd with inductor, the result is different from eager.

import torch
from torch import device

def forward(arg0_1: "f32[2, 3][3, 1]cpu"):
    slice_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 9223372036854775807)
    select_1: "f32[2][3]cpu" = torch.ops.aten.select.int(slice_2, 1, 2); slice_2 = None
    # File: /home/guilhermeleobas/git/pytorch/c.py:45 in g, code: y[:, 0] += x[:, 2].abs()
    slice_10: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 9223372036854775807); arg0_1 = None
    select_7: "f32[2][3]cpu" = torch.ops.aten.select.int(slice_10, 1, 2); slice_10 = None
    # File: /home/guilhermeleobas/git/pytorch/torch/_functorch/eager_transforms.py:1433 in grad_and_value_impl, code: flat_grad_input = _autograd_grad(
    full_1: "f32[][]cpu" = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    expand: "f32[2, 3][0, 0]cpu" = torch.ops.aten.expand.default(full_1, [2, 3]); full_1 = None
    new_empty_strided: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(expand, [2, 3], [3, 1])
    copy_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided, expand); new_empty_strided = expand = None
    as_strided_1: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_2, [2], [3], 0)
    clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_1, memory_format = torch.contiguous_format)
    full_2: "f32[2][1]cpu" = torch.ops.aten.full.default([2], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    copy_3: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_1, full_2); as_strided_1 = full_2 = None
    as_strided_scatter: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided_scatter.default(copy_2, copy_3, [2], [3], 0); copy_2 = copy_3 = None
    full_3: "f32[6][1]cpu" = torch.ops.aten.full.default([6], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    as_strided_3: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(full_3, [2], [3], 0)
    copy_4: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_3, clone); as_strided_3 = clone = None
    as_strided_scatter_1: "f32[6][1]cpu" = torch.ops.aten.as_strided_scatter.default(full_3, copy_4, [2], [3], 0); full_3 = copy_4 = None
    as_strided_6: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided.default(as_strided_scatter_1, [2, 3], [3, 1], 0); as_strided_scatter_1 = None
    add_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.add.Tensor(as_strided_scatter, as_strided_6); as_strided_scatter = as_strided_6 = None
    new_empty_strided_1: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(add_2, [2, 3], [3, 1])
    copy_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided_1, add_2); new_empty_strided_1 = add_2 = None
    as_strided_8: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_5, [2], [3], 0)
    clone_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_8, memory_format = torch.contiguous_format)
    copy_6: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_8, clone_1); as_strided_8 = None
    as_strided_scatter_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided_scatter.default(copy_5, copy_6, [2], [3], 0); copy_5 = copy_6 = None
    sign: "f32[2][1]cpu" = torch.ops.aten.sign.default(select_7); select_7 = None
    mul: "f32[2][1]cpu" = torch.ops.aten.mul.Tensor(clone_1, sign); clone_1 = sign = None
    full_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    select_scatter_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.select_scatter.default(full_4, mul, 1, 2); full_4 = mul = None
    full_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    slice_scatter_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice_scatter.default(full_5, select_scatter_4, 0, 0, 9223372036854775807); full_5 = select_scatter_4 = None
    new_empty_strided_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(as_strided_scatter_2, [2, 3], [3, 1])
    copy_7: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided_2, as_strided_scatter_2); new_empty_strided_2 = as_strided_scatter_2 = None
    as_strided_11: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_7, [2], [3], 0)
    clone_2: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_11, memory_format = torch.contiguous_format)
    full_6: "f32[2][1]cpu" = torch.ops.aten.full.default([2], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    copy_8: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_11, full_6); as_strided_11 = full_6 = None
    as_strided_scatter_3: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided_scatter.default(copy_7, copy_8, [2], [3], 0); copy_7 = copy_8 = None
    full_7: "f32[6][1]cpu" = torch.ops.aten.full.default([6], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    as_strided_13: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(full_7, [2], [3], 0)
    copy_9: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_13, clone_2); as_strided_13 = clone_2 = None
    as_strided_scatter_4: "f32[6][1]cpu" = torch.ops.aten.as_strided_scatter.default(full_7, copy_9, [2], [3], 0); full_7 = copy_9 = None
    as_strided_16: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided.default(as_strided_scatter_4, [2, 3], [3, 1], 0); as_strided_scatter_4 = None
    add_3: "f32[2, 3][3, 1]cpu" = torch.ops.aten.add.Tensor(as_strided_scatter_3, as_strided_16); as_strided_scatter_3 = as_strided_16 = None
    new_empty_strided_3: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(add_3, [2, 3], [3, 1])
    copy_10: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided_3, add_3); new_empty_strided_3 = add_3 = None
    as_strided_18: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_10, [2], [3], 0); copy_10 = None
    clone_3: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_18, memory_format = torch.contiguous_format); as_strided_18 = None
    sign_1: "f32[2][1]cpu" = torch.ops.aten.sign.default(select_1); select_1 = None
    mul_1: "f32[2][1]cpu" = torch.ops.aten.mul.Tensor(clone_3, sign_1); clone_3 = sign_1 = None
    full_8: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    select_scatter_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.select_scatter.default(full_8, mul_1, 1, 2); full_8 = mul_1 = None
    full_9: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    slice_scatter_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice_scatter.default(full_9, select_scatter_5, 0, 0, 9223372036854775807); full_9 = select_scatter_5 = None
    add_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.add.Tensor(slice_scatter_4, slice_scatter_5); slice_scatter_4 = slice_scatter_5 = None
    return (add_4,)

@torch.compile(backend='inductor', fullgraph=True)
def j(x):
    return forward(x)

x = torch.randn(2, 3)
print(j(x))
# (tensor([[0., 0., 1.],
#          [0., 0., 1.]]),) |
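One way to check a claim like "the culprit is inductor" is to compile the same function with several backends and see which ones disagree with eager. The sketch below assumes the function returns a single tensor and is deterministic. "eager", "aot_eager" and "inductor" are backend names accepted by torch.compile; compare_backends and its compile_fn parameter are hypothetical names introduced here for illustration.

```python
import torch
import torch._dynamo


def compare_backends(fn, x, backends=("eager", "aot_eager", "inductor"),
                     compile_fn=torch.compile):
    """Return {backend: bool} telling whether the compiled output matches eager.

    Assumes fn returns a single tensor. compile_fn is a hypothetical hook that
    defaults to torch.compile and exists only to make the helper easy to test.
    """
    expected = fn(x)
    report = {}
    for backend in backends:
        torch._dynamo.reset()  # drop cached compilations between backends
        compiled = compile_fn(fn, backend=backend)
        report[backend] = torch.allclose(compiled(x), expected)
    return report
```

Called as compare_backends(f, torch.randn(2, 3)) with f from the reduced reproducer, a mismatch that appears only under "inductor" would point at the same layer this comment identifies.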
@masnesral and @guilhermeleobas, could you please assign someone else to this ticket so it does not fall through the cracks? cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @rec |
I think the fact that it has "module: inductor" means it's being tracked and is on the backlog, but cc @eellison |
(Apparent) root cause here is
cc @Chillee. Repros on CUDA and CPU. |
cc: @zou3519 |
cc @laithsakka, @Chillee, @zou3519: who would be the appropriate owner of |
Laith and I only really know our way around the reinplacing logic for custom operators and user-defined Triton kernels. Reinplacing for regular built-in PyTorch operators goes down another path. |
okay assigning to @Chillee as owner of that pass |
Taking a look at this now. |
hi @Chillee Any update on this? |
Got a smaller repro:

import torch

torch.manual_seed(0)
x = torch.randn(2, 3)

def forward(x_1):
    # ones
    a = torch.ones([2, 3])
    c = torch.ones(2)
    a[:, 0].copy_(c)
    d = a.clone()
    e = torch.ops.aten.as_strided.default(d, [2], [3], 0)
    f = e.clone()
    g = torch.zeros(2)
    e.copy_(g)
    h = torch.zeros(2, 3)
    h[:, 0].copy_(f)
    add_1 = d + h
    return add_1

print(forward(x))
print(torch.compile(forward, fullgraph=True, backend="inductor")(x)) |
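The eager result of this repro can be traced by hand, which gives a fixed target to compare the compiled output against. The sketch below repeats the same operations with the intermediate values noted as comments. As simplifications of this sketch only, the unused argument is dropped and torch.as_strided is used in place of torch.ops.aten.as_strided.default; forward_traced is a hypothetical name.

```python
import torch


def forward_traced():
    a = torch.ones([2, 3])
    c = torch.ones(2)
    a[:, 0].copy_(c)                      # a is still all ones
    d = a.clone()                         # d: all ones
    e = torch.as_strided(d, [2], [3], 0)  # e is a view of d[:, 0]
    f = e.clone()                         # f: [1, 1], read BEFORE the write below
    e.copy_(torch.zeros(2))               # d becomes [[0, 1, 1], [0, 1, 1]]
    h = torch.zeros(2, 3)
    h[:, 0].copy_(f)                      # h: [[1, 0, 0], [1, 0, 0]]
    return d + h                          # every entry is 1


result = forward_traced()
print(result)
```

So the eager call is expected to produce a 2x3 tensor of ones regardless of x, and a compiled output that differs from that is the miscompilation under discussion.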
I know what is wrong, I just don't know how to fix it. Here's my attempt at a fix, which appears to fix the test case but breaks the rest of inductor: #152011. The problem is that the reinplacing pass is actually 3 graph passes. The first one, canonicalize_view_scatter_ops, adds new nodes into the graph. The second one, reinplace_inplaceable_ops, requires accurate aliasing information on the FakeTensors. canonicalize_view_scatter_ops does add accurate FakeTensor information to the new nodes it adds, but it does not update other nodes that depend on the new nodes. So the fix I tried was to run FakeTensorUpdater in between those two passes, but it didn't quite work. |
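To illustrate why a reinplacing pass needs accurate aliasing information, here is a plain eager-mode sketch (not inductor's actual implementation) that contrasts an out-of-place torch.as_strided_scatter with the in-place write it could be rewritten into. The rewrite is only equivalent when nothing reads the old contents through an alias afterwards. The function names are hypothetical.

```python
import torch


def out_of_place():
    d = torch.ones(2, 3)
    e = torch.as_strided(d, [2], [3], 0)  # alias of d[:, 0]
    # Functional scatter: returns a new tensor and leaves d untouched.
    d_new = torch.as_strided_scatter(d, torch.zeros(2), [2], [3], 0)
    late_read = e.clone()                 # still sees the old values [1, 1]
    return d_new, late_read


def naively_reinplaced():
    d = torch.ones(2, 3)
    e = torch.as_strided(d, [2], [3], 0)  # alias of d[:, 0]
    e.copy_(torch.zeros(2))               # in-place write mutates d through the alias
    late_read = e.clone()                 # now sees the new values [0, 0]
    return d, late_read


new_a, read_a = out_of_place()
new_b, read_b = naively_reinplaced()
print(torch.equal(new_a, new_b))    # the scattered tensors agree
print(torch.equal(read_a, read_b))  # the reads through the alias do not
```

The two versions agree on the scattered tensor but disagree on the value read through the alias, which is the kind of difference a pass cannot rule out if the alias metadata on dependent nodes is stale.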
🐛 Describe the bug
The following code sample shows a case where the composition of torch.compile and torch.func.grad of a function silently produces a wrong result.
The output of the JIT-compiled version of the function itself is the same as the output of the uncompiled version of the function.
Curiously, if one comments out the last line before the return statement in the test function test_func, namely processed_vars = (processed_vars.unsqueeze(-2) @ dependency_matrix).squeeze(dim=-2), the results of the compiled and uncompiled gradients are the same.
Error logs
No response
Minified repro
No response
Versions
Collecting environment information...
PyTorch version: 2.4.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39
Python version: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.11.0-1-t2-noble-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
CPU family: 6
Model: 158
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 13
CPU(s) scaling MHz: 74%
CPU max MHz: 5000.0000
CPU min MHz: 800.0000
BogoMIPS: 4800.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_act_window hwp_epp vnmi md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 2 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Mitigation; Microcode
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] numpydoc==1.7.0
[pip3] optree==0.12.1
[pip3] torch==2.4.1
[pip3] torchaudio==2.4.1
[pip3] torchopt==0.7.3
[pip3] torchvision==0.19.1
[conda] _anaconda_depends 2024.06 py312_mkl_2
[conda] blas 1.0 mkl
[conda] cpuonly 2.0 0 pytorch
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py312h5eee18b_1
[conda] mkl_fft 1.3.10 py312h5eee18b_0
[conda] mkl_random 1.2.7 py312h526ad5a_0
[conda] numpy 1.26.4 py312hc5e2394_0
[conda] numpy-base 1.26.4 py312h0da6c21_0
[conda] numpydoc 1.7.0 py312h06a4308_0
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch 2.4.1 py3.12_cpu_0 pytorch
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchaudio 2.4.1 py312_cpu pytorch
[conda] torchopt 0.7.3 pypi_0 pypi
[conda] torchvision 0.19.1 py312_cpu pytorch
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @BoyuanFeng @rec