Composition of torch.compile and torch.func.grad silently produces a wrong result. #136662

Open · anatolvitold opened this issue Sep 25, 2024 · 16 comments
Labels: actionable, high priority, module: correctness (silent), module: inductor, oncall: pt2, triaged, ubn ("unbreak now")

Comments

@anatolvitold commented Sep 25, 2024

🐛 Describe the bug

The following code sample shows a case where composing torch.compile with torch.func.grad of a function silently produces a wrong result.

The output of the JIT-compiled version of the function itself matches the output of the uncompiled version.

Curiously, if one comments out the last line before the return statement of test_func, namely processed_vars = (processed_vars.unsqueeze(-2) @ dependency_matrix).squeeze(dim=-2), the compiled and uncompiled gradients agree.

import torch

fixed_values = torch.tensor([[    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
                             [    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
                             [    0.0000,     0.0000,     0.0000,     0.0000,     0.0000,     0.0000],
                             [-1068.6385,     0.0000,    65.0000,     0.0000, -torch.inf,     0.0000]],
                            dtype=torch.float64)

free_vars_mask = torch.tensor([[ True,  True, False,  True,  True,  True],
                               [ True, False,  True,  True,  True,  True],
                               [ True,  True,  True,  True,  True,  True],
                               [False, False, False, False, False,  True]])

dependency_matrix = torch.tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 1.0000, 2.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.5000, 1.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],
                        
                                  [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],
                          
                                  [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]],
                          
                                  [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
                                   [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000]]],
                                 dtype=torch.float64)

free_vars_linear_indices = torch.where(free_vars_mask.ravel())[0]
free_vars_indices = tuple(map(lambda x: x.detach().clone(),
                              torch.unravel_index(free_vars_linear_indices,
                                                  free_vars_mask.shape)))

test_input = torch.tensor([-218.5399,    3.1056,   21.8333,    4.1535,    0.2   ,  144.8986,
                             49.6429,   60.1429,    3.9028,    0.59  ,  126.218 ,   -6.0392,
                             98.5   ,   35.5714,    4.8792,    0.2   ,    0.01  ],
                          dtype=torch.float64)

def to_constrained_params(vars):
    
    two_over_pi = torch.tensor(2.0 / torch.pi, dtype=vars.dtype)
    one = torch.tensor(1.0, dtype=vars.dtype)
    hundred = torch.tensor(100.0, dtype=vars.dtype)
    fifty = torch.tensor(50.0, dtype=vars.dtype)
    
    processed_vars = torch.zeros_like(vars)

    processed_vars[:, 0] += (two_over_pi * torch.arctan(vars[:, 0] / hundred) + one)*fifty
    processed_vars[:, 2] += torch.abs(vars[:, 2])
    processed_vars[:, 3] += torch.abs(vars[:, 3])
    processed_vars[:, 4] += torch.exp(vars[:, 4])

    processed_vars[:, 1] += torch.abs(vars[:, 1])
    processed_vars[[2], 1] *= -1
    
    weights = torch.abs(vars[:, 5])
    weights = weights / torch.sum(weights)
    processed_vars[:, 5] += weights

    return processed_vars
    
def test_func(free_vars):

    # Prepare free variables by placing them at the right indices of the All Variables Matrix
    free_vars_same_shape = torch.zeros_like(fixed_values, dtype=free_vars.dtype)
    free_vars_same_shape[free_vars_indices] += free_vars

    # Create the All Variables Matrix by merging free and fixed variables
    # and performing the Unconstrained-To-Constrained transformation
    processed_vars = to_constrained_params(fixed_values + free_vars_same_shape)

    processed_vars = (processed_vars.unsqueeze(-2) @ dependency_matrix).squeeze(dim=-2)
    
    return processed_vars.sum()

# the Test Function without JIT compilation
print(test_func(test_input))
# tensor(760.4707, dtype=torch.float64)

# the Test Function with JIT compilation - Same Output
print(torch.compile(test_func, fullgraph=True)(test_input))
# tensor(760.4707, dtype=torch.float64)


# The gradient of the test function without JIT
print(torch.func.grad(test_func, argnums=0)(test_input))
# tensor([5.5109e-02, 1.0000e+00, 1.5000e+00, 6.3656e+01, 0.0000e+00, 1.0270e-01,
#         1.0000e+00, 1.0000e+00, 4.9541e+01, 0.0000e+00, 1.2275e-01, 1.0000e+00,
#         1.0000e+00, 1.0000e+00, 1.3153e+02, 0.0000e+00, 0.0000e+00],
#        dtype=torch.float64)

# The gradient of the test function with JIT
print(torch.compile(torch.func.grad(test_func, argnums=0), fullgraph=True)(test_input))
# tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
#        dtype=torch.float64)
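
For anyone reproducing this, a small check that makes the divergence explicit (a sketch; it assumes test_func and test_input from the snippet above are in scope, and the aot_eager comparison is only a guess at which layer introduces the error):

eager_grad = torch.func.grad(test_func, argnums=0)(test_input)
inductor_grad = torch.compile(torch.func.grad(test_func, argnums=0),
                              fullgraph=True)(test_input)
aot_eager_grad = torch.compile(torch.func.grad(test_func, argnums=0),
                               fullgraph=True, backend="aot_eager")(test_input)

# Expected to pass if the problem is specific to inductor (an assumption to verify)
torch.testing.assert_close(aot_eager_grad, eager_grad)
# Expected to raise, since the inductor-compiled gradient is wrong
torch.testing.assert_close(inductor_grad, eager_grad)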

Error logs

No response

Minified repro

No response

Versions

Collecting environment information...
PyTorch version: 2.4.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.2.0-23ubuntu4) 13.2.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.11.0-1-t2-noble-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: GenuineIntel
Model name: Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz
CPU family: 6
Model: 158
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 13
CPU(s) scaling MHz: 74%
CPU max MHz: 5000.0000
CPU min MHz: 800.0000
BogoMIPS: 4800.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_act_window hwp_epp vnmi md_clear flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 2 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Mitigation; Microcode
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] numpydoc==1.7.0
[pip3] optree==0.12.1
[pip3] torch==2.4.1
[pip3] torchaudio==2.4.1
[pip3] torchopt==0.7.3
[pip3] torchvision==0.19.1
[conda] _anaconda_depends 2024.06 py312_mkl_2
[conda] blas 1.0 mkl
[conda] cpuonly 2.0 0 pytorch
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py312h5eee18b_1
[conda] mkl_fft 1.3.10 py312h5eee18b_0
[conda] mkl_random 1.2.7 py312h526ad5a_0
[conda] numpy 1.26.4 py312hc5e2394_0
[conda] numpy-base 1.26.4 py312h0da6c21_0
[conda] numpydoc 1.7.0 py312h06a4308_0
[conda] optree 0.12.1 pypi_0 pypi
[conda] pytorch 2.4.1 py3.12_cpu_0 pytorch
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchaudio 2.4.1 py312_cpu pytorch
[conda] torchopt 0.7.3 pypi_0 pypi
[conda] torchvision 0.19.1 py312_cpu pytorch

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @BoyuanFeng @rec

@zou3519 (Contributor) commented Sep 25, 2024

cc @guilhermeleobas. Tentatively marking hi-pri for silent incorrectness.

@zou3519 added the module: dynamo and dynamo-functorch labels Sep 25, 2024
@zou3519 self-assigned this Oct 1, 2024
@zou3519 added the triage review and triaged labels and removed the triage review label Oct 1, 2024
@masnesral (Contributor)

Commenting to bump the last update time: this is still an active WIP.

@guilhermeleobas (Collaborator) commented Nov 8, 2024

Thanks for the report @anatolvitold. I was able to reduce your reproducer to just:

import torch

def g(x):
    y = torch.zeros_like(x)
    y[:, 0] += x[:, 2].abs()
    y[:, 0] += x[:, 2].abs()
    return y.sum()


def f(x):
    return torch.func.grad(g)(x)


@torch.compile(backend='inductor', fullgraph=True)
def h(x):
    return f(x)

x = torch.randn(2, 3)
print(f(x))
# tensor([[ 0.,  0., -2.],
#         [ 0.,  0.,  2.]])
print(h(x))
# tensor([[ 0.,  0., -1.],
#         [ 0.,  0.,  1.]])

One thing to notice is that the error is only reproduced when the backend is inductor.
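
A quick way to sanity-check that (a sketch reusing f and x from the snippet above; each torch.compile call here gets its own backend):

import torch

for backend in ("eager", "aot_eager", "inductor"):
    compiled = torch.compile(f, backend=backend, fullgraph=True)
    print(backend, torch.allclose(compiled(x), f(x)))
# On an affected build, only the inductor line is expected to print False.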

@anatolvitold (Author)

Thank you @guilhermeleobas for your response. Given that inductor is the main backend for CPU, this is still quite unfortunate. Could you please provide some insight into the causes of the silent incorrectness?

@guilhermeleobas (Collaborator)

@zou3519, the culprit is indeed inductor: if we compile the graph generated by autograd with inductor, the result differs from eager.

import torch
from torch import device


def forward(arg0_1: "f32[2, 3][3, 1]cpu"):
    slice_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 9223372036854775807)
    select_1: "f32[2][3]cpu" = torch.ops.aten.select.int(slice_2, 1, 2);  slice_2 = None

    # File: /home/guilhermeleobas/git/pytorch/c.py:45 in g, code: y[:, 0] += x[:, 2].abs()
    slice_10: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 9223372036854775807);  arg0_1 = None
    select_7: "f32[2][3]cpu" = torch.ops.aten.select.int(slice_10, 1, 2);  slice_10 = None

    # File: /home/guilhermeleobas/git/pytorch/torch/_functorch/eager_transforms.py:1433 in grad_and_value_impl, code: flat_grad_input = _autograd_grad(
    full_1: "f32[][]cpu" = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    expand: "f32[2, 3][0, 0]cpu" = torch.ops.aten.expand.default(full_1, [2, 3]);  full_1 = None
    new_empty_strided: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(expand, [2, 3], [3, 1])
    copy_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided, expand);  new_empty_strided = expand = None
    as_strided_1: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_2, [2], [3], 0)
    clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_1, memory_format = torch.contiguous_format)
    full_2: "f32[2][1]cpu" = torch.ops.aten.full.default([2], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    copy_3: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_1, full_2);  as_strided_1 = full_2 = None
    as_strided_scatter: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided_scatter.default(copy_2, copy_3, [2], [3], 0);  copy_2 = copy_3 = None
    full_3: "f32[6][1]cpu" = torch.ops.aten.full.default([6], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    as_strided_3: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(full_3, [2], [3], 0)
    copy_4: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_3, clone);  as_strided_3 = clone = None
    as_strided_scatter_1: "f32[6][1]cpu" = torch.ops.aten.as_strided_scatter.default(full_3, copy_4, [2], [3], 0);  full_3 = copy_4 = None
    as_strided_6: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided.default(as_strided_scatter_1, [2, 3], [3, 1], 0);  as_strided_scatter_1 = None
    add_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.add.Tensor(as_strided_scatter, as_strided_6);  as_strided_scatter = as_strided_6 = None
    new_empty_strided_1: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(add_2, [2, 3], [3, 1])
    copy_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided_1, add_2);  new_empty_strided_1 = add_2 = None
    as_strided_8: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_5, [2], [3], 0)
    clone_1: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_8, memory_format = torch.contiguous_format)
    copy_6: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_8, clone_1);  as_strided_8 = None
    as_strided_scatter_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided_scatter.default(copy_5, copy_6, [2], [3], 0);  copy_5 = copy_6 = None
    sign: "f32[2][1]cpu" = torch.ops.aten.sign.default(select_7);  select_7 = None
    mul: "f32[2][1]cpu" = torch.ops.aten.mul.Tensor(clone_1, sign);  clone_1 = sign = None
    full_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    select_scatter_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.select_scatter.default(full_4, mul, 1, 2);  full_4 = mul = None
    full_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    slice_scatter_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice_scatter.default(full_5, select_scatter_4, 0, 0, 9223372036854775807);  full_5 = select_scatter_4 = None
    new_empty_strided_2: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(as_strided_scatter_2, [2, 3], [3, 1])
    copy_7: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided_2, as_strided_scatter_2);  new_empty_strided_2 = as_strided_scatter_2 = None
    as_strided_11: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_7, [2], [3], 0)
    clone_2: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_11, memory_format = torch.contiguous_format)
    full_6: "f32[2][1]cpu" = torch.ops.aten.full.default([2], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    copy_8: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_11, full_6);  as_strided_11 = full_6 = None
    as_strided_scatter_3: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided_scatter.default(copy_7, copy_8, [2], [3], 0);  copy_7 = copy_8 = None
    full_7: "f32[6][1]cpu" = torch.ops.aten.full.default([6], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    as_strided_13: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(full_7, [2], [3], 0)
    copy_9: "f32[2][3]cpu" = torch.ops.aten.copy.default(as_strided_13, clone_2);  as_strided_13 = clone_2 = None
    as_strided_scatter_4: "f32[6][1]cpu" = torch.ops.aten.as_strided_scatter.default(full_7, copy_9, [2], [3], 0);  full_7 = copy_9 = None
    as_strided_16: "f32[2, 3][3, 1]cpu" = torch.ops.aten.as_strided.default(as_strided_scatter_4, [2, 3], [3, 1], 0);  as_strided_scatter_4 = None
    add_3: "f32[2, 3][3, 1]cpu" = torch.ops.aten.add.Tensor(as_strided_scatter_3, as_strided_16);  as_strided_scatter_3 = as_strided_16 = None
    new_empty_strided_3: "f32[2, 3][3, 1]cpu" = torch.ops.aten.new_empty_strided.default(add_3, [2, 3], [3, 1])
    copy_10: "f32[2, 3][3, 1]cpu" = torch.ops.aten.copy.default(new_empty_strided_3, add_3);  new_empty_strided_3 = add_3 = None
    as_strided_18: "f32[2][3]cpu" = torch.ops.aten.as_strided.default(copy_10, [2], [3], 0);  copy_10 = None
    clone_3: "f32[2][1]cpu" = torch.ops.aten.clone.default(as_strided_18, memory_format = torch.contiguous_format);  as_strided_18 = None
    sign_1: "f32[2][1]cpu" = torch.ops.aten.sign.default(select_1);  select_1 = None
    mul_1: "f32[2][1]cpu" = torch.ops.aten.mul.Tensor(clone_3, sign_1);  clone_3 = sign_1 = None
    full_8: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    select_scatter_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.select_scatter.default(full_8, mul_1, 1, 2);  full_8 = mul_1 = None
    full_9: "f32[2, 3][3, 1]cpu" = torch.ops.aten.full.default([2, 3], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    slice_scatter_5: "f32[2, 3][3, 1]cpu" = torch.ops.aten.slice_scatter.default(full_9, select_scatter_5, 0, 0, 9223372036854775807);  full_9 = select_scatter_5 = None
    add_4: "f32[2, 3][3, 1]cpu" = torch.ops.aten.add.Tensor(slice_scatter_4, slice_scatter_5);  slice_scatter_4 = slice_scatter_5 = None
    return (add_4,)


@torch.compile(backend='inductor', fullgraph=True)
def j(x):
    return forward(x)


x = torch.randn(2, 3)
print(j(x))
# (tensor([[0., 0., 1.],
#        [0., 0., 1.]]),)
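
For reference, running the extracted graph eagerly (without torch.compile) on the same input gives the eager result described above; a minimal comparison, assuming forward, j, and x from the block above:

out_eager = forward(x)      # plain eager execution of the extracted graph
out_compiled = j(x)         # inductor-compiled execution
# Expected to raise on affected builds: the compiled graph appears to drop one of
# the two gradient contributions to the third column.
torch.testing.assert_close(out_compiled, out_eager)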

@guilhermeleobas added the module: inductor label and removed the dynamo-functorch label Nov 13, 2024
@anatolvitold (Author)

@masnesral and @guilhermeleobas, could you please assign someone else to this ticket, so it will not fall through the cracks?

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @rec

@masnesral (Contributor)

> Could you please assign someone else to this ticket, so it will not fall through the cracks?

I think the fact that it has the "module: inductor" label means it's being tracked and is on the backlog, but cc @eellison.

@eellison (Contributor) commented Dec 2, 2024

The (apparent) root cause here is reinplace_inplaceable_ops.

Binary search completed for inductor - post_grad_passes. The bisect number is 8. Debug info: reinplace_inplaceable_ops

cc @Chillee.

Repros on both CUDA and CPU.
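
If it helps whoever picks this up, one way to inspect what the post-grad passes (including this one) did to the graph is to dump the post-grad graphs while running the small repro, e.g. via TORCH_LOGS="post_grad_graphs"; a programmatic sketch, assuming that logging artifact is available in this build:

import torch._logging

# Roughly equivalent to running with TORCH_LOGS="post_grad_graphs"; the logged
# graph can then be diffed against the pre-inductor graph posted earlier.
torch._logging.set_logs(post_grad_graphs=True)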

@Chillee (Collaborator) commented Dec 3, 2024

cc: @zou3519

@eellison (Contributor) commented Dec 4, 2024

cc @laithsakka, @Chillee, @zou3519: who would be the appropriate owner of reinplace_inplaceable_ops to look into this? It would be great to find an owner, since this is a high-priority issue with silent incorrectness.

@zou3519 (Contributor) commented Dec 4, 2024

Laith and I only really know our way around the reinplacing logic for custom operators and user-defined Triton kernels. Reinplacing for regular built-in PyTorch operators goes down another path.

@eellison (Contributor) commented Dec 4, 2024

Okay, assigning to @Chillee as the owner of that pass.

@mlazos added and removed the triaged label Jan 14, 2025
@Chillee (Collaborator) commented Feb 3, 2025

Taking a look at this now.

@Danielmic (Contributor)

Hi @Chillee, any update on this?

@mlazos assigned mlazos and unassigned Chillee Mar 18, 2025
@zou3519 added the module: correctness (silent) and actionable labels Apr 10, 2025
@zou3519 self-assigned this Apr 22, 2025
@zou3519 (Contributor) commented Apr 23, 2025

Got a smaller repro:

import torch

torch.manual_seed(0)
x = torch.randn(2, 3)

def forward(x_1):
    # ones
    a = torch.ones([2, 3])
    c = torch.ones(2)
    a[:, 0].copy_(c)

    d = a.clone()
    e = torch.ops.aten.as_strided.default(d, [2], [3], 0)
    f = e.clone()

    g = torch.zeros(2)
    e.copy_(g)

    h = torch.zeros(2, 3)
    h[:, 0].copy_(f)

    add_1 = d + h
    return add_1

print(forward(x))
print(torch.compile(forward, fullgraph=True, backend="inductor")(x))
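
To make the mismatch fail loudly instead of relying on eyeballing the two prints, one could append something like the following (a sketch; on affected builds the assertion is expected to raise):

# Compare the inductor-compiled result against eager; raises on affected builds.
torch.testing.assert_close(
    torch.compile(forward, fullgraph=True, backend="inductor")(x),
    forward(x),
)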

@zou3519 (Contributor) commented Apr 24, 2025

I know what is wrong; I just don't know how to fix it. Here's my attempt at fixing it (which appears to fix the test case, but breaks the rest of inductor: #152011).

The problem is that the reinplacing pass is actually three graph passes. The first one, canonicalize_view_scatter_ops, adds new nodes to the graph. The second one, reinplace_inplaceable_ops, requires accurate aliasing information on the FakeTensors.

canonicalize_view_scatter_ops does add accurate FakeTensor information to the new nodes it inserts, but it does not update the other nodes that depend on those new nodes.

So the fix I tried was to run FakeTensorUpdater in between those two passes, but it didn't quite work.
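
A rough sketch of the general idea (using the public FakeTensorProp helper; this is only an illustration, not inductor's actual FakeTensorUpdater nor the change attempted in #152011): re-running fake-tensor propagation over the graph module between the two passes would refresh the per-node metadata that the reinplacing pass presumably relies on.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def refresh_fake_tensor_meta(gm: torch.fx.GraphModule, *example_inputs):
    # Re-run fake-tensor shape/aliasing propagation so that nodes downstream of
    # the ones inserted by canonicalize_view_scatter_ops get up-to-date meta["val"].
    mode = FakeTensorMode(allow_non_fake_inputs=True)
    FakeTensorProp(gm, mode=mode).propagate(*example_inputs)
    return gm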

@mlazos removed their assignment Apr 24, 2025
@zou3519 added the ubn ("unbreak now") label May 6, 2025