@FindHao FindHao commented Nov 11, 2025

Problem

When tracing kernel launches during CUDA graph capture (used by triton.testing.do_bench_cudagraph), TritonParse's tensor argument extraction caused crashes with:

torch.AcceleratorError: CUDA error: operation not permitted when stream is capturing

The error occurred because str(tensor_value) in extract_arg_info() triggered PyTorch's __repr__ method, which accesses tensor data. CUDA graph capture forbids certain memory operations, causing the entire capture to be invalidated.

Root Cause

In add_launch_metadata(), the code unconditionally called extract_arg_info() to collect detailed tensor metadata. During CUDA graph capture:

  1. str(arg_value) triggers torch.Tensor.__repr__()
  2. __repr__() calls torch.masked_select() and torch.isfinite() to format tensor data
  3. These operations are prohibited during CUDA graph capture
  4. CUDA raises cudaErrorStreamCaptureUnsupported, invalidating the entire graph

Solution

Detect and skip argument extraction during CUDA graph capture:

  1. In add_launch_metadata(), check torch.cuda.is_current_stream_capturing() before calling extract_arg_info()
  2. If capturing, return minimal metadata with a note that extraction was skipped
  3. Remove redundant try-except blocks in extract_arg_info() since it's never called during capture

Key changes:

def add_launch_metadata(grid, metadata, arg_dict, inductor_args=None):
    # Check if we're in CUDA graph capture mode
    is_capturing = False
    if TORCH_INSTALLED:
        try:
            is_capturing = torch.cuda.is_current_stream_capturing()
        except (AttributeError, RuntimeError):
            pass  # Handle API unavailability in older PyTorch versions
    
    if is_capturing:
        # Return minimal metadata without argument extraction
        return {
            "launch_metadata_tritonparse": (
                grid,
                metadata._asdict(),
                {"_note": "argument extraction skipped during CUDA graph capture"},
                {},
            )
        }
    
    # Normal path: extract detailed argument information
    extracted_args = extract_arg_info(arg_dict)
    ...
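A downstream consumer of the hook output can branch on the `_note` marker to tell the minimal capture-time payload apart from a full one. This hedged sketch assumes the tuple layout shown above (grid, metadata dict, extracted args, extra dict); `was_extraction_skipped` is an illustrative helper, not part of TritonParse:

```python
# Sketch: consume the hook's payload while tolerating the minimal
# form returned during CUDA graph capture. The key name and tuple
# layout mirror the snippet above.

def was_extraction_skipped(launch_metadata):
    grid, meta, extracted_args, extra = launch_metadata["launch_metadata_tritonparse"]
    return extracted_args.get("_note") == (
        "argument extraction skipped during CUDA graph capture"
    )


minimal = {
    "launch_metadata_tritonparse": (
        (1, 1, 1),
        {"num_warps": 4},
        {"_note": "argument extraction skipped during CUDA graph capture"},
        {},
    )
}
print(was_extraction_skipped(minimal))  # → True
```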

Impact

  • Fixes: Kernel benchmarking with triton.testing.do_bench_cudagraph no longer crashes
  • Preserves: Full argument extraction for normal (non-capture) kernel launches
  • Safe: Gracefully handles older PyTorch versions that lack is_current_stream_capturing()
  • Clean: Removes redundant defensive code since capture detection happens at the entry point

Testing

Verified with Triton's Hopper GEMM benchmark (hopper-gemm-ws_test.py), which uses do_bench_cudagraph extensively. The benchmark previously failed with CUDA graph capture errors and now runs successfully.

Related Issues

Fixes: meta-pytorch/tritonbench#632


meta-codesync bot commented Nov 11, 2025

@FindHao has imported this pull request. If you are a Meta employee, you can view this in D86722827.


meta-codesync bot commented Nov 12, 2025

@FindHao merged this pull request in fb7197b.
