
conv3d bfloat16 wrong result #163539

@ZejiaZheng


🐛 Describe the bug

conv3d produces wrong results for large input shapes after the cuDNN upgrade (https://github.com/pytorch/pytorch/pull/155122/commits). Update: the issue comes from cuDNN 9.10.2.
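Since the regression is attributed to cuDNN 9.10.2, it helps to confirm which cuDNN build PyTorch actually loads at runtime. The environment report below only says the library is "probably one of" several files. `torch.backends.cudnn.version()` returns the loaded version as a single integer. The helpers below are hypothetical (not part of PyTorch). They assume cuDNN's version encoding: major*10000 + minor*100 + patch from cuDNN 9 onward, and major*1000 + minor*100 + patch for cuDNN 8 and earlier.

```python
def decode_cudnn_version(v: int) -> tuple[int, int, int]:
    """Split an integer cuDNN version into (major, minor, patch).

    Hypothetical helper. Assumes cuDNN 9+ encodes versions as
    major*10000 + minor*100 + patch (e.g. 91002 -> 9.10.2) and
    cuDNN 8 and earlier as major*1000 + minor*100 + patch (e.g. 8902 -> 8.9.2).
    """
    if v >= 90000:  # cuDNN 9 and later
        major, rest = divmod(v, 10000)
    else:  # cuDNN 8 and earlier
        major, rest = divmod(v, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


# The cuDNN release this report traces the regression to
SUSPECT_VERSION = (9, 10, 2)


def is_suspect_cudnn(v: int) -> bool:
    """True if the decoded version equals the cuDNN release named in the report."""
    return decode_cudnn_version(v) == SUSPECT_VERSION
```

On a CUDA machine, the argument would come from `torch.backends.cudnn.version()`, which returns `None` when cuDNN is unavailable, so that case needs to be checked first.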
To repro

import torch
import torch.nn as nn
import torch.nn.functional as F


def test_conv2d_bf16_vs_fp32():
    """Test case for large-spatial-dimension conv3d with bf16 vs fp32"""

    # Check for CUDA availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Set up the problematic configuration
    batch_size = 1
    in_channels = 16
    out_channels = 16
    in_height, in_width = 1282, 722
    out_height, out_width = 1280, 720
    kernel_size = (1, 3, 3)
    stride = (1, 1, 1)
    dilation = (1, 1, 1)

    # Note: This is a 3D convolution based on kernel/stride/dilation having 3 values
    # Adjusting dimensions accordingly
    in_depth = 124
    out_depth = 124

    # Create random input (on CUDA when available)
    torch.manual_seed(42)
    x_fp32 = torch.randn(batch_size, in_channels, in_depth, in_height, in_width,
                        dtype=torch.float32, device=device)
    x_bf16 = x_fp32.to(torch.bfloat16)

    # Create conv layer with random weights
    conv = nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,  # No padding: H and W shrink by 2 (3x3 kernel); depth is unchanged
        dilation=dilation,
        bias=False
    ).to(device)

    # Initialize weights
    nn.init.kaiming_normal_(conv.weight)

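    # Forward pass in fp32.
    # Note: Module.to() converts the module in place and returns it, so conv_fp32
    # and conv_bf16 (below) are the same object; this pass must run before the
    # bf16 conversion.
    conv_fp32 = conv.to(torch.float32)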
    with torch.no_grad():
        y_fp32 = conv_fp32(x_fp32)

    # Forward pass in bf16
    conv_bf16 = conv.to(torch.bfloat16)
    with torch.no_grad():
        y_bf16 = conv_bf16(x_bf16)

    # Convert bf16 output to fp32 for comparison
    y_bf16_as_fp32 = y_bf16.to(torch.float32)

    # Compute differences
    diff = torch.abs(y_fp32 - y_bf16_as_fp32)
    mean_abs_diff = diff.mean().item()
    l2_diff = torch.sqrt((diff ** 2).sum()).item()
    l2_y_fp32 = torch.sqrt((y_fp32 ** 2).sum()).item()
    rel_l2 = l2_diff / l2_y_fp32 if l2_y_fp32 > 0 else float('inf')
    max_abs_diff = diff.max().item()

    # Print results
    print("Configuration:")
    print("  dtype_in=torch.bfloat16")
    print(f"  kernel={kernel_size} stride={stride} dilation={dilation}")
    print(f"  x.shape={tuple(x_bf16.shape)}  y.shape={tuple(y_bf16.shape)}")
    print(f"  diff: mean|Δ|={mean_abs_diff:.6f}  L2={l2_diff:.6f}  relL2={rel_l2:.6f}  max|Δ|={max_abs_diff:.6f}")

    # Check if error is concerning (> 50% relative error)
    if rel_l2 > 0.5:
        print(f"\n⚠️  WARNING: Relative L2 error is {rel_l2:.1%} - this is extremely high!")
    print("\nLast 20 elements of outputs for manual inspection:")
    print(y_fp32.flatten()[-20:])
    print(y_bf16_as_fp32.flatten()[-20:])

    return {
        'mean_abs_diff': mean_abs_diff,
        'l2_diff': l2_diff,
        'rel_l2': rel_l2,
        'max_abs_diff': max_abs_diff,
        'y_fp32_shape': tuple(y_fp32.shape),
        'y_bf16_shape': tuple(y_bf16.shape)
    }

if __name__ == "__main__":
    print("Testing 3D Convolution (based on 3-element kernel/stride/dilation):")
    results_3d = test_conv2d_bf16_vs_fp32()
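The shapes and tensor sizes in this repro can be checked by hand with the per-dimension output-size formula from the `torch.nn.Conv3d` documentation: out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. A minimal stdlib-only sketch; `conv_out_size` is an illustrative helper, not a PyTorch API:

```python
import math


def conv_out_size(n_in: int, kernel: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length along one dimension, per the torch.nn.Conv3d docs formula."""
    return (n_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Repro configuration: (depth, height, width) with kernel (1, 3, 3), no padding
in_dhw = (124, 1282, 722)
kernel_dhw = (1, 3, 3)
out_dhw = tuple(conv_out_size(n, k) for n, k in zip(in_dhw, kernel_dhw))

batch, channels = 1, 16
in_numel = batch * channels * math.prod(in_dhw)
out_numel = batch * channels * math.prod(out_dhw)

# Bytes per element: 4 for float32, 2 for bfloat16
for name, numel in (("input", in_numel), ("output", out_numel)):
    print(f"{name}: {numel:,} elements, "
          f"{numel * 4 / 1e9:.2f} GB fp32, {numel * 2 / 1e9:.2f} GB bf16")
```

This gives an output shape of (1, 16, 124, 1280, 720), matching the printed `y.shape`, with roughly 1.8 billion elements per tensor (about 3.7 GB each in bf16).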

On the current release (torch 2.8.0), the output is:

Testing 3D Convolution (based on 3-element kernel/stride/dilation):
Using device: cuda
Configuration:
  dtype_in=torch.bfloat16
  kernel=(1, 3, 3) stride=(1, 1, 1) dilation=(1, 1, 1)
  x.shape=(1, 16, 124, 1282, 722)  y.shape=(1, 16, 124, 1280, 720)
  diff: mean|Δ|=0.649738  L2=54093.402344  relL2=0.909543  max|Δ|=12.276418

⚠️  WARNING: Relative L2 error is 91.0% - this is extremely high!

Last 20 elements of outputs for manual inspection:
tensor([ 2.0922,  0.1431,  1.3097,  1.7380, -0.2731,  1.2708,  1.4640,  1.5003,
         1.1057, -0.2895,  0.8647,  2.7035, -0.8348,  1.3429, -2.3612,  1.3541,
         1.8041,  0.6064,  1.7154,  0.5617], device='cuda:0')
tensor([ 0.6992, -1.1797, -0.5195,  0.5898, -1.1328,  1.0781,  1.3125, -0.4102,
         2.8125,  0.2520,  0.9883, -0.9648,  0.4043, -3.1562,  2.4844,  1.5547,
        -0.2451,  0.3145, -0.1484,  1.1875], device='cuda:0')

Several of the last 20 elements are completely wrong.
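For scale: bfloat16 keeps 8 significant bits (7 stored plus 1 implicit), so rounding a single value to bf16 changes it by at most 2^-8 ≈ 0.39% in relative terms. Per-element deviations of roughly that order are expected from bf16; errors larger than the values themselves, as above, are not. The stdlib-only sketch below emulates float32-to-bf16 round-to-nearest-even by rounding away the low 16 bits of the float32 bit pattern. It is an illustration only: it ignores NaN handling and models rounding of a single value, not the accumulation inside the convolution.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    Sketch: x is first rounded to float32, then the low 16 bits of the float32
    bit pattern are rounded away. NaN inputs are not handled specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit, so exact ties round to even
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000  # keep sign, exponent and 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# A few fp32 reference values from the printed output
for v in (2.0922, 0.1431, -2.3612, 2.7035):
    r = to_bfloat16(v)
    print(f"{v:+.4f} -> {r:+.6f}  rel. error {abs(r - v) / abs(v):.2e}")
```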

On the earlier nightly build, torch 2.8.0.dev20250527+cu126, the test passed:

Testing 3D Convolution (based on 3-element kernel/stride/dilation):
Using device: cuda
Configuration:
  dtype_in=torch.bfloat16
  kernel=(1, 3, 3) stride=(1, 1, 1) dilation=(1, 1, 1)
  x.shape=(1, 16, 124, 1282, 722)  y.shape=(1, 16, 124, 1280, 720)
  diff: mean|Δ|=0.003143  L2=171.637970  relL2=0.002886  max|Δ|=0.033743

Last 20 elements of outputs for manual inspection:
tensor([ 2.0922,  0.1431,  1.3097,  1.7380, -0.2731,  1.2708,  1.4640,  1.5003,
         1.1057, -0.2895,  0.8647,  2.7035, -0.8348,  1.3429, -2.3612,  1.3541,
         1.8041,  0.6064,  1.7154,  0.5617], device='cuda:0')
tensor([ 2.0938,  0.1416,  1.3125,  1.7344, -0.2754,  1.2656,  1.4609,  1.5000,
         1.1094, -0.2871,  0.8672,  2.7031, -0.8320,  1.3359, -2.3750,  1.3516,
         1.7969,  0.6055,  1.7109,  0.5586], device='cuda:0')
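The script's relL2 metric is ||y_fp32 - y_bf16||_2 / ||y_fp32||_2. Applying the same formula, in plain Python, to just the 20 printed elements shows the contrast between the two runs without a GPU. This sketch only covers the printed tail of each tensor, so its values will not equal the full-tensor relL2 figures reported above.

```python
import math


def rel_l2(ref, other):
    """Relative L2 error ||ref - other|| / ||ref||, as computed in the repro script."""
    num = math.sqrt(sum((a - b) ** 2 for a, b in zip(ref, other)))
    den = math.sqrt(sum(a * a for a in ref))
    return num / den if den > 0 else float("inf")


# Last 20 elements as printed in the issue
fp32_ref = [2.0922, 0.1431, 1.3097, 1.7380, -0.2731, 1.2708, 1.4640, 1.5003,
            1.1057, -0.2895, 0.8647, 2.7035, -0.8348, 1.3429, -2.3612, 1.3541,
            1.8041, 0.6064, 1.7154, 0.5617]
bf16_torch_2_8_0 = [0.6992, -1.1797, -0.5195, 0.5898, -1.1328, 1.0781, 1.3125, -0.4102,
                    2.8125, 0.2520, 0.9883, -0.9648, 0.4043, -3.1562, 2.4844, 1.5547,
                    -0.2451, 0.3145, -0.1484, 1.1875]
bf16_nightly_20250527 = [2.0938, 0.1416, 1.3125, 1.7344, -0.2754, 1.2656, 1.4609, 1.5000,
                         1.1094, -0.2871, 0.8672, 2.7031, -0.8320, 1.3359, -2.3750, 1.3516,
                         1.7969, 0.6055, 1.7109, 0.5586]

print(f"torch 2.8.0:      relL2 = {rel_l2(fp32_ref, bf16_torch_2_8_0):.4f}")
print(f"nightly 20250527: relL2 = {rel_l2(fp32_ref, bf16_nightly_20250527):.4f}")
```

On these printed samples, the torch 2.8.0 tail has a relative error above 100%, while the older nightly stays well under 1%.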

Versions

Collecting environment information...
PyTorch version: 2.8.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.100+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.183.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.10.2
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        52 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               208
On-line CPU(s) list:                  0-207
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
CPU family:                           6
Model:                                143
Thread(s) per core:                   2
Core(s) per socket:                   52
Socket(s):                            2
Stepping:                             8
BogoMIPS:                             5399.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            4.9 MiB (104 instances)
L1i cache:                            3.3 MiB (104 instances)
L2 cache:                             208 MiB (104 instances)
L3 cache:                             210 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-51,104-155
NUMA node1 CPU(s):                    52-103,156-207
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] alias-free-torch==0.0.6
[pip3] ema-pytorch==0.7.7
[pip3] facenet-pytorch==2.5.1
[pip3] guided-filter-pytorch==3.7.5
[pip3] lovely-numpy==0.2.13
[pip3] mypy==1.15.0
[pip3] mypy-boto3-s3==1.37.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] nvtx==0.2.8
[pip3] onnx==1.18.0
[pip3] onnx-graphsurgeon==0.5.2
[pip3] onnx-ir==0.1.3
[pip3] onnxruntime==1.22.0
[pip3] onnxruntime-gpu==1.21.0
[pip3] onnxscript==0.3.1
[pip3] open-clip-torch==2.22.0
[pip3] pytorch-lightning==2.2.4
[pip3] pytorch-msssim==1.0.0
[pip3] pytorchvideo==0.1.5
[pip3] rw_torch_extensions==0.1.6
[pip3] torch==2.8.0
[pip3] torch-fidelity==0.3.0
[pip3] torch-stoi==0.2.3
[pip3] torch-tb-profiler==0.4.1
[pip3] torch_tensorrt==2.8.0
[pip3] torchao==0.12.0
[pip3] torchaudio==2.8.0
[pip3] torchcodec==0.6.0
[pip3] torchdata==0.11.0
[pip3] torchdiffeq==0.2.5
[pip3] torchft-nightly==2025.7.27
[pip3] torchlibrosa==0.1.0
[pip3] torchmetrics==1.3.1
[pip3] torchray==1.0.0.2
[pip3] torchvision==0.23.0
[pip3] triton==3.4.0
[pip3] vector-quantize-pytorch==1.20.11
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @csarofeen @ptrblck @xwang233 @eqy @jerryzh168

Labels: high priority, module: correctness (silent), module: cuda, module: cudnn, module: regression, triaged
