cuda graphs produce two additional kernel calls #143572

Closed
trporn opened this issue Dec 19, 2024 · 4 comments
Labels
module: cuda graphs (Ability to capture and then replay streams of CUDA kernels)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


trporn commented Dec 19, 2024

🐛 Describe the bug

When using CUDA graph capture, the replay() function launches two additional kernels before the graph itself is launched.
The additional launches are fillFunctor kernels, probably issued by replay_prologue() (line 229 in https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/cuda/CUDAGraph.cpp).
This is unexpected, and it makes CUDA graphs impractical for small code sections, because two extra launches per replay can cancel out the launch overhead that graphing is meant to save.
To see the problem, run nsys profile -t cuda python file.py
on the following code:

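import torch

# Sizes: batch N, input features D_in, hidden H, target width D_out
N, D_in, H, D_out = 640, 4096, 2048, 1024
model = torch.nn.Linear(D_in, H).cuda()

# Static placeholders; the captured graph always reads from these addresses
static_input = torch.randn(N, D_in, device='cuda')
static_target = torch.randn(N, D_out, device='cuda')  # not read by the graph

# Warm up on a side stream before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        y_pred = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture a single forward pass
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y_pred = model(static_input)

real_inputs = [torch.rand_like(static_input) for _ in range(100)]
real_targets = [torch.rand_like(static_target) for _ in range(100)]

# Replay: copy new data into the placeholders, then launch the graph.
# In the nsys trace, each g.replay() shows two fill kernels before the graph launch.
for data, target in zip(real_inputs, real_targets):
    static_input.copy_(data)
    static_target.copy_(target)
    g.replay()

# Wait for all queued work so the profiler records every replay
torch.cuda.synchronize()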
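To make the "small code sections" point concrete, here is a minimal counting sketch. It assumes every host-side launch call costs the same (a simplification: a graph launch and a kernel launch need not cost the same in practice), and the function names launch_calls and break_even are hypothetical. Running a section of k kernels eagerly takes k launch calls; replaying it as a graph takes one graph launch plus the two prologue fill launches seen in the trace.

```python
def launch_calls(num_kernels: int, graphed: bool, prologue_launches: int = 2) -> int:
    """Host-side launch calls to run a section of num_kernels kernels once.

    Eager: one launch call per kernel.
    Graphed: one graph launch plus the replay prologue's extra kernel launches
    (two fill kernels in the reported trace). Assumes equal cost per call.
    """
    if num_kernels < 1:
        raise ValueError("num_kernels must be at least 1")
    if not graphed:
        return num_kernels
    return 1 + prologue_launches


def break_even(prologue_launches: int = 2) -> int:
    """Smallest section size for which graphing strictly reduces launch calls.

    The graphed count is the constant 1 + prologue_launches, so eager must
    exceed it: num_kernels > 1 + prologue_launches.
    """
    return prologue_launches + 2


for k in (1, 2, 3, 4, 8):
    print(f"{k} kernels: eager={launch_calls(k, False)}, graphed={launch_calls(k, True)}")
print("break-even with the 2-launch prologue:", break_even(2))
print("break-even without a prologue:", break_even(0))
```

Under this model, a graph with the two-launch prologue only starts saving launch calls once the section contains four or more kernels, whereas without the prologue two kernels would already suffice.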

Versions

Collecting environment information...
PyTorch version: 2.4.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.11.0 (main, Mar 1 2023, 18:26:19) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.40
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX 3500 Ada Generation Laptop GPU
Nvidia driver version: 553.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: 13th Gen Intel(R) Core(TM) i7-13800H
CPU family: 6
Model: 186
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
Stepping: 2
BogoMIPS: 5836.79
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 480 KiB (10 instances)
L1i cache: 320 KiB (10 instances)
L2 cache: 12.5 MiB (10 instances)
L3 cache: 24 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] nvtx==0.2.10
[pip3] pytorch-lightning==2.4.0
[pip3] torch==2.4.1
[pip3] torchmetrics==1.6.0
[pip3] triton==3.0.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] nvtx 0.2.10 pypi_0 pypi
[conda] pytorch-lightning 2.4.0 pypi_0 pypi
[conda] torch 2.4.1 pypi_0 pypi
[conda] torchmetrics 1.6.0 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

cc @mcarilli @ezyang @eellison @penguinwu

ngimel (Collaborator) commented Dec 19, 2024

This comes from setting the random seeds; without that, a CUDA graph cannot guarantee correct behavior if the graphed function uses random numbers. Perhaps we could add an option to disable this if the user says that random seeds are not required.
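Why the seeds have to be refreshed can be illustrated with a pure-Python toy model. It assumes, in simplified form, that a captured RNG kernel reads its seed and counter offset from fixed device-side slots rather than from the generator, and that the prologue writes the generator's current seed and offset into those slots (the two fills) and then advances the generator by the graph's total RNG consumption. All names here (toy_counter_rng, ToyGenerator, ToyGraph, refresh) are hypothetical, and the hash-based RNG only stands in for Philox; this is not PyTorch's implementation.

```python
import hashlib


def toy_counter_rng(seed: int, counter: int) -> float:
    """Toy counter-based RNG: a pure function of (seed, counter) in [0, 1).

    Stands in for Philox only in spirit; it is NOT the real algorithm.
    """
    digest = hashlib.sha256(f"{seed}:{counter}".encode()).digest()
    return int.from_bytes(digest[:8], "little") / 2**64


class ToyGenerator:
    """Host-side generator state: a seed plus a running counter offset."""

    def __init__(self, seed: int) -> None:
        self.seed = seed
        self.offset = 0

    def eager_draw(self, n: int) -> list[float]:
        """Eager sampling: read the current state, then advance it by n."""
        out = [toy_counter_rng(self.seed, self.offset + i) for i in range(n)]
        self.offset += n
        return out


class ToyGraph:
    """A 'captured graph' that draws n numbers per replay.

    Like captured kernels, it never reads the generator directly; it reads
    seed/offset slots whose location was fixed at capture time.
    """

    def __init__(self, n: int) -> None:
        self.n = n
        self.wholegraph_increments = n  # total counter advance per replay
        self.seed_slot = 0
        self.offset_slot = 0

    def replay_prologue(self, gen: ToyGenerator) -> None:
        # Two writes, modelling the two fill kernels seen in the trace
        self.seed_slot = gen.seed
        self.offset_slot = gen.offset
        # Reserve this replay's counter range on the host generator
        gen.offset += self.wholegraph_increments

    def replay(self, gen: ToyGenerator, refresh: bool = True) -> list[float]:
        if refresh:
            self.replay_prologue(gen)
        return [toy_counter_rng(self.seed_slot, self.offset_slot + i)
                for i in range(self.n)]


# With the prologue: successive replays continue the generator's stream.
gen = ToyGenerator(seed=0)
graph = ToyGraph(n=4)
first = graph.replay(gen)
second = graph.replay(gen)

# Without it: every replay re-reads the same slots and repeats its numbers.
stale_gen = ToyGenerator(seed=0)
stale = ToyGraph(n=4)
stale_first = stale.replay(stale_gen, refresh=False)
stale_second = stale.replay(stale_gen, refresh=False)
```

In this model, two refreshed replays produce exactly what two eager draws of four numbers would, while skipping the refresh makes every replay repeat the same "random" numbers, which is the incorrect behavior the prologue guards against.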

ngimel added the module: cuda graphs label Dec 19, 2024
eellison (Contributor) commented:

We should at least skip the increment if wholegraph_increments is 0.
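One way to read this suggestion, sketched as a toy model: if the captured graph drew no random numbers (its recorded increment is 0), its seed/offset slots are never read, so the prologue has nothing to refresh and could skip both the fills and the generator advance. The class names and the skip_if_unused switch are hypothetical illustrations, not PyTorch API, and this is not necessarily how the eventual fix works.

```python
from dataclasses import dataclass, field


@dataclass
class ToyGenerator:
    seed: int = 0
    offset: int = 0


@dataclass
class ToyReplay:
    wholegraph_increments: int  # RNG counter advance recorded when capture ended
    skip_if_unused: bool = True  # hypothetical switch for the proposed behavior
    launches: list[str] = field(default_factory=list)  # host-side launches, in order

    def replay_prologue(self, gen: ToyGenerator) -> None:
        if self.skip_if_unused and self.wholegraph_increments == 0:
            # The graph drew no random numbers, so its seed/offset slots are
            # never read: no fills and no generator advance are needed.
            return
        self.launches += ["fill(seed)", "fill(offset)"]
        gen.offset += self.wholegraph_increments

    def replay(self, gen: ToyGenerator) -> None:
        self.replay_prologue(gen)
        self.launches.append("graph_launch")


no_rng = ToyReplay(wholegraph_increments=0)
no_rng.replay(ToyGenerator())
print(no_rng.launches)  # ['graph_launch']

always = ToyReplay(wholegraph_increments=0, skip_if_unused=False)
always.replay(ToyGenerator())
print(always.launches)  # ['fill(seed)', 'fill(offset)', 'graph_launch']
```

A graph that does use random numbers (a nonzero increment) still gets both fills and advances the generator on every replay, so correctness for RNG-using graphs is unchanged in this model.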

cpuhrsch added the triaged label Dec 20, 2024
trporn (Author) commented Dec 23, 2024

Thanks for this.
If the prologue exists only for the random state, it would be good to give the user the choice. As things stand, in inference even large models are CPU-bound, and graphing (sometimes small) submodules is the only way available from within PyTorch to reduce this CPU-side overhead. I might just compile an extension with these lines removed.

ngimel (Collaborator) commented Apr 30, 2025

Fixed in #143777

ngimel closed this as completed Apr 30, 2025

4 participants