Fix CUDA Graph Capture Error in Tensor Argument Extraction #197
+21
−1
Problem
When tracing kernel launches during CUDA graph capture (used by `triton.testing.do_bench_cudagraph`), TritonParse's tensor argument extraction crashed with a CUDA graph capture error.
The error occurred because `str(tensor_value)` in `extract_arg_info()` triggered PyTorch's `__repr__` method, which accesses tensor data. CUDA graph capture forbids certain memory operations, so this access invalidated the entire capture.
Root Cause
In `add_launch_metadata()`, the code unconditionally called `extract_arg_info()` to collect detailed tensor metadata. During CUDA graph capture:
1. `str(arg_value)` triggers `torch.Tensor.__repr__()`
2. `__repr__()` calls `torch.masked_select()` and `torch.isfinite()` to format tensor data
3. These operations fail with `cudaErrorStreamCaptureUnsupported`, invalidating the entire graph
Solution
Detect and skip argument extraction during CUDA graph capture:
- In `add_launch_metadata()`, check `torch.cuda.is_current_stream_capturing()` before calling `extract_arg_info()`
- […] `extract_arg_info()` since it's never called during capture
Key changes:
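The diff itself is not reproduced here. As a rough illustration of the guard described above, here is a minimal sketch. The function names ending in `_sketch`, their signatures, and the returned dictionary shape are assumptions made for illustration, not TritonParse's actual code; only `torch.cuda.is_available()` and `torch.cuda.is_current_stream_capturing()` are real PyTorch calls.

```python
from typing import Any, Callable, Dict


def stream_is_capturing() -> bool:
    """Return True only when the current CUDA stream is being captured."""
    try:
        import torch
    except ImportError:
        return False  # no PyTorch installed, so no capture can be in progress
    # Check is_available() first so CPU-only builds never reach the CUDA query.
    return torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()


def extract_arg_info_sketch(args: Dict[str, Any]) -> Dict[str, Dict[str, str]]:
    """Hypothetical stand-in: str(value) is the call that is unsafe during capture."""
    return {
        name: {"type": type(value).__name__, "repr": str(value)}
        for name, value in args.items()
    }


def add_launch_metadata_sketch(
    args: Dict[str, Any],
    capturing: Callable[[], bool] = stream_is_capturing,
) -> Dict[str, Any]:
    """Hypothetical hook: skip argument extraction while a graph is being captured."""
    if capturing():
        # Do not touch argument values here; str(tensor) would launch device ops.
        # Returning an empty mapping is an assumption about the fallback shape.
        return {"extracted_args": {}}
    return {"extracted_args": extract_arg_info_sketch(args)}
```

Passing the capture check as a parameter keeps the guard testable on machines without a GPU: a caller can substitute `lambda: True` to confirm that no argument is stringified on the capture path.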
Impact
- `triton.testing.do_bench_cudagraph` no longer crashes
- […] `is_current_stream_capturing()`
Testing
Verified with Triton's Hopper GEMM benchmark (`hopper-gemm-ws_test.py`), which uses `do_bench_cudagraph` extensively. It previously failed with CUDA graph capture errors and now runs successfully.
Related Issues
Fixes: meta-pytorch/tritonbench#632