🐛 Describe the bug
conv3d produces wrong results for large input shapes after the cuDNN upgrade (https://github.com/pytorch/pytorch/pull/155122/commits). Update: the issue comes from cuDNN 9.10.2.
To repro
import torch
import torch.nn as nn
import torch.nn.functional as F


def test_conv3d_bf16_vs_fp32():
    """Test case for large-spatial-size conv3d with bf16 vs fp32"""
    # Check for CUDA availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Set up the problematic configuration
    batch_size = 1
    in_channels = 16
    out_channels = 16
    in_height, in_width = 1282, 722
    out_height, out_width = 1280, 720  # 3x3 spatial kernel, no padding: each shrinks by 2
    kernel_size = (1, 3, 3)
    stride = (1, 1, 1)
    dilation = (1, 1, 1)
    # Note: this is a 3D convolution (kernel/stride/dilation have 3 values each),
    # so the input also gets a depth dimension
    in_depth = 124
    out_depth = 124  # depth kernel size is 1, so depth is unchanged

    # Create random input on the selected device
    torch.manual_seed(42)
    x_fp32 = torch.randn(batch_size, in_channels, in_depth, in_height, in_width,
                         dtype=torch.float32, device=device)
    x_bf16 = x_fp32.to(torch.bfloat16)

    # Create conv layer with random weights
    conv = nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,  # no padding, so the output is smaller than the input
        dilation=dilation,
        bias=False,
    ).to(device)
    # Initialize weights
    nn.init.kaiming_normal_(conv.weight)

    # Forward pass in fp32
    # Module.to() converts in place and returns the same module, so the fp32
    # pass must run before the bf16 conversion below
    conv_fp32 = conv.to(torch.float32)
    with torch.no_grad():
        y_fp32 = conv_fp32(x_fp32)

    # Forward pass in bf16
    conv_bf16 = conv.to(torch.bfloat16)
    with torch.no_grad():
        y_bf16 = conv_bf16(x_bf16)

    # Convert bf16 output to fp32 for comparison
    y_bf16_as_fp32 = y_bf16.to(torch.float32)

    # Compute differences
    diff = torch.abs(y_fp32 - y_bf16_as_fp32)
    mean_abs_diff = diff.mean().item()
    l2_diff = torch.sqrt((diff ** 2).sum()).item()
    l2_y_fp32 = torch.sqrt((y_fp32 ** 2).sum()).item()
    rel_l2 = l2_diff / l2_y_fp32 if l2_y_fp32 > 0 else float('inf')
    max_abs_diff = diff.max().item()

    # Print results
    print("Configuration:")
    print(" dtype_in=torch.bfloat16")
    print(f" kernel={kernel_size} stride={stride} dilation={dilation}")
    print(f" x.shape={tuple(x_bf16.shape)} y.shape={tuple(y_bf16.shape)}")
    print(f" diff: mean|Δ|={mean_abs_diff:.6f} L2={l2_diff:.6f} relL2={rel_l2:.6f} max|Δ|={max_abs_diff:.6f}")

    # Check if error is concerning (> 50% relative error)
    if rel_l2 > 0.5:
        print(f"\n⚠️ WARNING: Relative L2 error is {rel_l2:.1%} - this is extremely high!")

    print("\nLast 20 elements of outputs for manual inspection:")
    print(y_fp32.flatten()[-20:])
    print(y_bf16_as_fp32.flatten()[-20:])

    return {
        'mean_abs_diff': mean_abs_diff,
        'l2_diff': l2_diff,
        'rel_l2': rel_l2,
        'max_abs_diff': max_abs_diff,
        'y_fp32_shape': tuple(y_fp32.shape),
        'y_bf16_shape': tuple(y_bf16.shape),
    }


if __name__ == "__main__":
    print("Testing 3D Convolution (based on 3-element kernel/stride/dilation):")
    results_3d = test_conv3d_bf16_vs_fp32()
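The hard-coded output sizes in the script, and the y.shape printed in the logs, follow from the usual convolution output-size formula applied per spatial dimension: out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. A small pure-Python sketch (no PyTorch needed) that checks this and counts elements, to put a number on "large input shapes"; the helper names `conv_out_size`, `conv3d_out_shape` and `numel` are hypothetical:

```python
import math


def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of one convolution dimension (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv3d_out_shape(in_shape, out_channels, kernel, stride=(1, 1, 1),
                     padding=(0, 0, 0), dilation=(1, 1, 1)):
    """Output shape (N, C_out, D, H, W) for an NCDHW input shape."""
    n, _, *spatial = in_shape
    dims = [
        conv_out_size(s, k, st, p, d)
        for s, k, st, p, d in zip(spatial, kernel, stride, padding, dilation)
    ]
    return (n, out_channels, *dims)


def numel(shape):
    """Total number of elements in a tensor of the given shape."""
    return math.prod(shape)


x_shape = (1, 16, 124, 1282, 722)
y_shape = conv3d_out_shape(x_shape, out_channels=16, kernel=(1, 3, 3))
print(y_shape)  # (1, 16, 124, 1280, 720)
print(numel(x_shape), numel(y_shape))
```

For this configuration the input holds about 1.84 billion elements (roughly 3.7 GB in bf16) and the output about 1.83 billion.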
On the current torch 2.8.0 release, the output is:
Testing 3D Convolution (based on 3-element kernel/stride/dilation):
Using device: cuda
Configuration:
dtype_in=torch.bfloat16
kernel=(1, 3, 3) stride=(1, 1, 1) dilation=(1, 1, 1)
x.shape=(1, 16, 124, 1282, 722) y.shape=(1, 16, 124, 1280, 720)
diff: mean|Δ|=0.649738 L2=54093.402344 relL2=0.909543 max|Δ|=12.276418
⚠️ WARNING: Relative L2 error is 91.0% - this is extremely high!
Last 20 elements of outputs for manual inspection:
tensor([ 2.0922, 0.1431, 1.3097, 1.7380, -0.2731, 1.2708, 1.4640, 1.5003,
1.1057, -0.2895, 0.8647, 2.7035, -0.8348, 1.3429, -2.3612, 1.3541,
1.8041, 0.6064, 1.7154, 0.5617], device='cuda:0')
tensor([ 0.6992, -1.1797, -0.5195, 0.5898, -1.1328, 1.0781, 1.3125, -0.4102,
2.8125, 0.2520, 0.9883, -0.9648, 0.4043, -3.1562, 2.4844, 1.5547,
-0.2451, 0.3145, -0.1484, 1.1875], device='cuda:0')
Many of the last 20 elements are completely wrong compared with the fp32 reference.
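For scale: bf16 keeps 8 significant bits, so a single rounding to bf16 changes a value by at most 2^-8 (about 0.39%) relative to its magnitude. A minimal pure-Python sketch of float32-to-bf16 round-to-nearest-even makes that concrete; `round_to_bf16` is a hypothetical helper, and it skips NaN/inf handling:

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even).

    Sketch only: x is first rounded to float32 by struct.pack, and
    NaN/inf inputs are not handled.
    """
    # Reinterpret the float32 bit pattern as an unsigned 32-bit integer
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # bf16 keeps the upper 16 bits; adding 0x7FFF plus the lowest kept bit
    # implements round-half-to-even on the 16 discarded bits
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y


for v in [2.0922, -2.3612, 0.5617]:
    print(v, round_to_bf16(v))
```

For example, `round_to_bf16(2.0922)` gives 2.09375, the bf16 neighbour of the first fp32 reference value. Differences such as 2.7035 vs -0.9648 in the failing output are far beyond anything bf16 rounding can explain.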
On the previous torch version 2.8.0.dev20250527+cu126, the test passed:
Testing 3D Convolution (based on 3-element kernel/stride/dilation):
Using device: cuda
Configuration:
dtype_in=torch.bfloat16
kernel=(1, 3, 3) stride=(1, 1, 1) dilation=(1, 1, 1)
x.shape=(1, 16, 124, 1282, 722) y.shape=(1, 16, 124, 1280, 720)
diff: mean|Δ|=0.003143 L2=171.637970 relL2=0.002886 max|Δ|=0.033743
Last 20 elements of outputs for manual inspection:
tensor([ 2.0922, 0.1431, 1.3097, 1.7380, -0.2731, 1.2708, 1.4640, 1.5003,
1.1057, -0.2895, 0.8647, 2.7035, -0.8348, 1.3429, -2.3612, 1.3541,
1.8041, 0.6064, 1.7154, 0.5617], device='cuda:0')
tensor([ 2.0938, 0.1416, 1.3125, 1.7344, -0.2754, 1.2656, 1.4609, 1.5000,
1.1094, -0.2871, 0.8672, 2.7031, -0.8320, 1.3359, -2.3750, 1.3516,
1.7969, 0.6055, 1.7109, 0.5586], device='cuda:0')
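To compare the two runs side by side, the printed tails can be scored with the same relative-L2 metric the script uses. The sketch below is pure Python on the 20 values copied from the logs; since it is only a small sample, its numbers will not equal the full-tensor relL2 values (0.909543 and 0.002886). The helper names are hypothetical:

```python
import math

# Last 20 fp32 reference values (identical in both runs)
y_ref = [2.0922, 0.1431, 1.3097, 1.7380, -0.2731, 1.2708, 1.4640, 1.5003,
         1.1057, -0.2895, 0.8647, 2.7035, -0.8348, 1.3429, -2.3612, 1.3541,
         1.8041, 0.6064, 1.7154, 0.5617]
# bf16 outputs from the failing torch 2.8.0 run
y_bad = [0.6992, -1.1797, -0.5195, 0.5898, -1.1328, 1.0781, 1.3125, -0.4102,
         2.8125, 0.2520, 0.9883, -0.9648, 0.4043, -3.1562, 2.4844, 1.5547,
         -0.2451, 0.3145, -0.1484, 1.1875]
# bf16 outputs from the passing 2.8.0.dev20250527+cu126 run
y_good = [2.0938, 0.1416, 1.3125, 1.7344, -0.2754, 1.2656, 1.4609, 1.5000,
          1.1094, -0.2871, 0.8672, 2.7031, -0.8320, 1.3359, -2.3750, 1.3516,
          1.7969, 0.6055, 1.7109, 0.5586]


def rel_l2(ref, out):
    """Relative L2 error ||ref - out|| / ||ref||, as in the repro script."""
    if len(ref) != len(out):
        raise ValueError("length mismatch")
    num = math.sqrt(sum((r - o) ** 2 for r, o in zip(ref, out)))
    den = math.sqrt(sum(r * r for r in ref))
    return num / den if den > 0 else float("inf")


def max_abs_diff(ref, out):
    """Largest elementwise absolute difference."""
    return max(abs(r - o) for r, o in zip(ref, out))


print(f"failing: relL2={rel_l2(y_ref, y_bad):.3f} max|Δ|={max_abs_diff(y_ref, y_bad):.3f}")
print(f"passing: relL2={rel_l2(y_ref, y_good):.4f} max|Δ|={max_abs_diff(y_ref, y_good):.4f}")
```

On this sample the failing run's relative L2 error exceeds 1, meaning it is worse than an all-zero output would be, while the passing run stays well under 1%.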
Versions
Collecting environment information...
PyTorch version: 2.8.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.10.12 (main, Aug 15 2025, 14:32:43) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.100+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.9.86
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.183.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.10.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.10.2
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 208
On-line CPU(s) list: 0-207
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 52
Socket(s): 2
Stepping: 8
BogoMIPS: 5399.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 4.9 MiB (104 instances)
L1i cache: 3.3 MiB (104 instances)
L2 cache: 208 MiB (104 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-51,104-155
NUMA node1 CPU(s): 52-103,156-207
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] alias-free-torch==0.0.6
[pip3] ema-pytorch==0.7.7
[pip3] facenet-pytorch==2.5.1
[pip3] guided-filter-pytorch==3.7.5
[pip3] lovely-numpy==0.2.13
[pip3] mypy==1.15.0
[pip3] mypy-boto3-s3==1.37.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] nvtx==0.2.8
[pip3] onnx==1.18.0
[pip3] onnx-graphsurgeon==0.5.2
[pip3] onnx-ir==0.1.3
[pip3] onnxruntime==1.22.0
[pip3] onnxruntime-gpu==1.21.0
[pip3] onnxscript==0.3.1
[pip3] open-clip-torch==2.22.0
[pip3] pytorch-lightning==2.2.4
[pip3] pytorch-msssim==1.0.0
[pip3] pytorchvideo==0.1.5
[pip3] rw_torch_extensions==0.1.6
[pip3] torch==2.8.0
[pip3] torch-fidelity==0.3.0
[pip3] torch-stoi==0.2.3
[pip3] torch-tb-profiler==0.4.1
[pip3] torch_tensorrt==2.8.0
[pip3] torchao==0.12.0
[pip3] torchaudio==2.8.0
[pip3] torchcodec==0.6.0
[pip3] torchdata==0.11.0
[pip3] torchdiffeq==0.2.5
[pip3] torchft-nightly==2025.7.27
[pip3] torchlibrosa==0.1.0
[pip3] torchmetrics==1.3.1
[pip3] torchray==1.0.0.2
[pip3] torchvision==0.23.0
[pip3] triton==3.4.0
[pip3] vector-quantize-pytorch==1.20.11
[conda] Could not collect
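The environment dump above infers the cuDNN version from library filenames. At runtime, `torch.backends.cudnn.version()` returns the version actually loaded as an integer. A small sketch to decode it, assuming the cuDNN 9 encoding `major * 10000 + minor * 100 + patch` (cuDNN 8.x used `major * 1000 + minor * 100 + patch` instead); the helper names are hypothetical:

```python
def decode_cudnn_version(code: int) -> tuple:
    """Split a cuDNN 9.x integer version code into (major, minor, patch).

    Assumes the cuDNN 9 encoding major * 10000 + minor * 100 + patch.
    """
    major, rest = divmod(code, 10000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def is_reported_cudnn(code: int) -> bool:
    """True if the code decodes to 9.10.2, the version this report points at."""
    return decode_cudnn_version(code) == (9, 10, 2)


print(decode_cudnn_version(91002))  # (9, 10, 2)
```

Calling something like `decode_cudnn_version(torch.backends.cudnn.version())` would then confirm whether the process really loaded 9.10.2, which matters when both a pip wheel (`nvidia-cudnn-cu12`) and system copies under `/usr/lib/x86_64-linux-gnu` are installed, as here.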
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @csarofeen @ptrblck @xwang233 @eqy @jerryzh168