[CUDAEvent.h] support external cuda events in cudagraphs #146145


Open
wants to merge 9 commits into gh/nmacchioni/39/base

Conversation

nmacchioni
Contributor

@nmacchioni commented Jan 31, 2025

[ghstack-poisoned]

pytorch-bot bot commented Jan 31, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146145

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0242089 with merge base 72405b0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

nmacchioni added a commit that referenced this pull request Jan 31, 2025
ghstack-source-id: c3e7442
Pull Request resolved: #146145
@nmacchioni added the release notes: cuda label Jan 31, 2025
[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 4, 2025
ghstack-source-id: e24100b
Pull Request resolved: #146145
@galv
Collaborator

galv commented Feb 4, 2025

Hello @nmacchioni, your current failure is due to this line:

s1.wait_stream(s0)

Basically, that is doing the following:

cudaEventRecordWithFlags(e, s0, cudaEventWaitExternal);
cudaStreamWaitEvent(s1, e);

The proper way to create a cross-stream dependency is described here: https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cross-stream-dependencies-and-events

While it is obscure, cudaEventWaitExternal is documented here: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html#group__CUDART__TYPES_1g0c23426b7252eaa9cef695859991304e

Event is captured in the graph as an external event node when performing stream capture

Basically, it means that in your case a cuda event record node is created in the cuda graph, rather than a plain stream dependency (with no event record node at all).

And this line has problems for similar reasons:

s0.wait_stream(s1)

My impression is that you need to make a new API if you want to use external events like this. Cuda stream capture decided long ago to use events for internal dependency ordering, to make stream capture of existing workloads into cuda graphs easy, but that unfortunately does not cooperate with the "external" concept of events you are using here.
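
For contrast, a minimal sketch of the fork/join pattern from the programming guide, written with plain wait_stream() and default-flag events; during capture these become graph dependencies rather than external event nodes (untested, just to illustrate the intended pattern):

import torch

s0 = torch.cuda.Stream()
s1 = torch.cuda.Stream()

g = torch.cuda.CUDAGraph()
# Capture on s0; wait_stream() with default-flag events is turned into a
# plain graph edge during capture, not an external event record/wait node.
# (Warmup on a side stream omitted for brevity.)
with torch.cuda.graph(g, stream=s0):
    a = torch.ones(4, device="cuda")
    s1.wait_stream(s0)              # fork: s1 depends on the work captured on s0
    with torch.cuda.stream(s1):
        b = a * 2                   # captured on s1
    s0.wait_stream(s1)              # join: s0 depends on the work captured on s1

This is the pattern the programming guide calls fork/join, and it is what breaks once wait_stream() starts recording with external flags.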

[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 5, 2025
ghstack-source-id: 03eca54
Pull Request resolved: #146145
[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 5, 2025
ghstack-source-id: 031c62a
Pull Request resolved: #146145
[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: f26e281
Pull Request resolved: #146145
@nmacchioni requested a review from ngimel February 6, 2025 00:05
[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 78538a2
Pull Request resolved: #146145
[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 8d57650
Pull Request resolved: #146145
@nmacchioni
Contributor Author

Thank you for the feedback @galv!

My impression is that you need to make a new API if you want to use external events like this. Cuda stream capture decided long ago to use events for internal dependency ordering, to make stream capture of existing workloads into cuda graphs easy, but that unfortunately does not cooperate with the "external" concept of events you are using here.

I think this is what @ngimel and I have also concluded. It is unfortunate that we can't reuse the enable_timing=True API here, given that we have no control over how users do things at the moment (i.e. a user could technically set enable_timing=True, record the event, wait on that event, and never use the timing itself).

For now, I've added timing_only=True as an option, which should allow us to differentiate between the two types of events; I would appreciate any feedback on the implementation.

Although timing_only=True should unblock my other goals for now, I'm thinking we might also want to start warning users who set enable_timing=True and then proceed to use those events for intra-graph synchronization. I'm not sure what the deprecation cycle looks like for something like this, but I'd hope that the warning would lead to a future deprecation of events constructed with enable_timing=True being used for intra-graph synchronization, so that we could collapse the API back down.
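
For illustration, a rough (untested) sketch of how I expect these events to be used with a captured graph; the timing_only keyword is the one proposed in this PR (it gets renamed later in this thread):

import torch

# timing_only must be combined with enable_timing=True, per the header comment
# added in this PR; both events are recorded with cudaEventRecordExternal
# during capture. (Warmup omitted for brevity.)
start = torch.cuda.Event(enable_timing=True, timing_only=True)
end = torch.cuda.Event(enable_timing=True, timing_only=True)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    start.record()                  # external record node at the start of the region
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
    end.record()                    # external record node at the end of the region

g.replay()
torch.cuda.synchronize()
print(start.elapsed_time(end))      # milliseconds between the two records

elapsed_time() should then measure just the region between the two external records on each replay.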

[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 32b9b46
Pull Request resolved: #146145
* `cudaEventRecordExternal`; `cudaEventTimingOnly` enables the distinction between these
* two use cases. `cudaEventEnableTiming` must be set in conjunction with `cudaEventTimingOnly`.
*/
#define cudaEventTimingOnly 0x05
Collaborator

I don't think this is correct: flags are bits, and 1, 2, and 4 are already taken, so your next option is 8, not 5. Even then it is not future-proof: what if nvidia decides to add some other flags that will conflict?

Contributor Author

Ah yes, you're totally right; changing to 0x08. I'm not sure how we could be 100% future-proof here, but it should be simple enough to update if nvidia does add new flags. Also, I'm hoping we eventually consolidate the API back to the original form through some form of deprecation cycle, so this won't be around forever.

@@ -163,12 +163,15 @@ class Event(torch._C._CudaEventBase):
https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html
"""

def __new__(cls, enable_timing=False, blocking=False, interprocess=False):
def __new__(
cls, enable_timing=False, blocking=False, interprocess=False, timing_only=False
Collaborator

modify the docstring with the new arg

[ghstack-poisoned]
nmacchioni added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 6c6625b
Pull Request resolved: #146145
/*
* `cudaEventTimingOnly` is a torch-specific flag that is used to indicate that
* the CUDAEvent will only be used for timing, and never for synchronization.
* CUDAEvents used for intra-graph timing must be recorded with `cudaEventRecordExternal`,
Collaborator

  • CUDAEvents used for intra-graph timing must be recorded with cudaEventRecordExternal,
  • whereas CUDAEvents used for inter-graph synchronization must never be recorded with

I believe you need to swap "intra-graph" and "inter-graph" in this comment. "intra-graph" means within the graph (i.e., the meaning in https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cross-stream-dependencies-and-events)

@galv
Collaborator

galv commented Feb 7, 2025

I think this is what @ngimel and I have also concluded. It is unfortunate that we can't reuse the enable_timing=True API here, given that we have no control over how users do things at the moment (i.e. a user could technically set enable_timing=True, record the event, wait on that event, and never use the timing itself).

That makes sense. However, can I recommend that you change the "timing_only" argument to "external"? It is a bit niche, and no one has ever needed to do this in pytorch, but you can record an "external" event within a cuda graph in order to signal to some other stream or CPU code waiting on that event that the previous part of that cuda graph is done (without having to wait until the whole cuda graph is done). (You can also have part of a cuda graph "wait" on work done outside that cuda graph by waiting on an event.)

You would then have to remove this check: https://github.com/pytorch/pytorch/pull/146145/files#diff-3fb8ef31de1e7a14456d5edf1c252a28fa9643f71586cb5ca2ad6a332de9d861R44-R47

Does that make sense? While no one is using this functionality right now, this naming is more future-proof for when someone might want to do what I am describing.

@ngimel
Collaborator

ngimel commented Feb 7, 2025

@galv that's a great idea, @nmacchioni let's go with external

@nmacchioni
Contributor Author

Yeah I'm cool with changing it to "external"

That makes sense. However, can I recommend that you change the "timing_only" argument to "external"? It is a bit niche, and no one has ever needed to do this in pytorch, but you can record an "external" event within a cuda graph in order to signal to some other stream or CPU code waiting on that event that the previous part of that cuda graph is done (without having to wait until the whole cuda graph is done). (You can also have part of a cuda graph "wait" on work done outside that cuda graph by waiting on an event.)

I kind of understand this, but could you provide a quick code example? I want to make sure I'm understanding correctly. Thank you!

@ngimel
Collaborator

ngimel commented Feb 8, 2025

The code example would be using this event after the graph, and ordering that work correctly w.r.t. the work inside the graph. Similarly, we should also do cudaStreamWaitEvent with cudaEventWaitExternal for external wait events, for the mirror situation where you record the event externally and then want the work in the graph to be ordered w.r.t. this event.
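
Roughly something like this (untested sketch, assuming the external flag proposed above; the data-producing helpers are placeholders):

import torch

# Work recorded outside the graph gates work inside the graph: during capture,
# ready.wait() should become cudaStreamWaitEvent(..., cudaEventWaitExternal),
# i.e. an external event-wait node rather than a capture dependency.
ready = torch.cuda.Event(external=True)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    ready.wait()                    # external wait node inside the graph
    out = consume_external_data()   # placeholder for in-graph work

# At replay time, a producer stream records `ready` first, and the in-graph
# work is ordered after that record.
producer = torch.cuda.Stream()
with torch.cuda.stream(producer):
    produce_external_data()         # placeholder for out-of-graph work
    ready.record()
g.replay()

The in-graph wait should then be satisfied by whatever record of `ready` the producer stream enqueued most recently.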

@galv
Collaborator

galv commented Feb 9, 2025

I kind of understand this, but could you provide a quick code example? I want to make sure I'm understanding correctly. Thank you!

@nmacchioni no problem. It took me some time and I tried to simplify as much as possible, but here is one. Note that it may have some errors (I didn't run this).

import threading

import torch

def foo():
    # `external=` is the flag proposed in this PR.
    iter_finished_event = torch.cuda.Event(external=torch.cuda.is_current_stream_capturing())
    llm_output_tokens_gpu = do_llm_inference()  # placeholder for the actual model call
    llm_output_tokens_cpu = llm_output_tokens_gpu.to(device=torch.device("cpu")).pin_memory()
    iter_finished_event.record() # This will use cudaEventRecordExternal in cudaEventRecordWithFlags
    return iter_finished_event, llm_output_tokens_cpu

keep_generating = True

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    iter_finished_event, llm_output_tokens_cpu = foo()

cv = threading.Condition()

def cpu_thread(event, llm_output_tokens_cpu):
    global keep_generating
    event.synchronize()
    send_to_client(llm_output_tokens_cpu)  # placeholder for the actual response path
    if llm_output_tokens_cpu[-1] == eos_token:
        keep_generating = False
    with cv:
        cv.notify()

thread = threading.Thread(target=cpu_thread, args=(iter_finished_event, llm_output_tokens_cpu))
thread.start()

g.replay()
while keep_generating:
    with cv:
        cv.wait()
    g.replay()

This code is intended to solve a niche problem: Normally, when doing LLM inference at batch size 1, you generate one token at a time, checking after each iteration whether that token is EOS. Naively done, you will synchronize the entire stream after every iteration to do this check, since you need to copy the data to the CPU before checking for EOS. Under the above approach, you will basically enqueue a second g.replay() while the first is still running, thus reducing your CPU overhead.

I realize now that my code does not do what I explained. But this is something that I know people have explored before. I've done things like this in other circumstances, but it's a bit hard to make a very simple example. Nevertheless, does this example help? It took me about 20 minutes to write this out and I would rather get it out than hold it back.

@galv
Collaborator

galv commented Feb 9, 2025

Sorry, what I had in mind was the technique where, if you unroll N steps of LLM inference, you would use N-1 events in order to synchronize a CPU thread with the point in time at which each of the intermediate tokens is copied to the CPU. This will reduce your CPU overhead of launching the cudagraph by a factor of N, at the cost of doing potentially redundant computation if your LLM does not need to run N times.

@galv
Collaborator

galv commented Feb 9, 2025

This pseudocode describes the loop-unrolling trick:

import torch

def foo(unroll_factor):
    events = []
    llm_output_tokens_cpu_list = []
    for i in range(unroll_factor):
        # `external=` is the flag proposed in this PR.
        iter_finished_event = torch.cuda.Event(external=torch.cuda.is_current_stream_capturing())
        llm_output_tokens_gpu = do_llm_inference()  # placeholder for the actual model call
        llm_output_tokens_cpu = llm_output_tokens_gpu.to(device=torch.device("cpu")).pin_memory()
        iter_finished_event.record() # This will use cudaEventRecordExternal in cudaEventRecordWithFlags
        events.append(iter_finished_event)
        llm_output_tokens_cpu_list.append(llm_output_tokens_cpu)
    return events, llm_output_tokens_cpu_list

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, capture_error_mode="relaxed"): # relaxed because I am not doing warmup.
    iter_finished_events, llm_output_tokens_cpu_list = foo(unroll_factor=2)

keep_generating = True

while keep_generating:
    g.replay()
    for event, llm_output_tokens_cpu in zip(iter_finished_events, llm_output_tokens_cpu_list):
        event.synchronize()
        if keep_generating:
            send_response_to_client(llm_output_tokens_cpu)  # placeholder for the actual response path
        if llm_output_tokens_cpu[-1] == eos_token:
            keep_generating = False

@ngimel changed the title from "[CUDAEvent.h] support cuda events in cudagraphs" to "[CUDAEvent.h] support external cuda events in cudagraphs" Feb 10, 2025
@galv
Collaborator

galv commented Feb 11, 2025

@nmacchioni so unfortunately the code I showed above doesn't work because pin_memory() does not work inside of stream capture for several reasons. pin_memory() would have to be called outside of the stream capture for my example (which is actually how I've implemented it before). Thanks to @ngimel for pointing this out. I'm taking a look at what it takes to do pin_memory() inside of stream capture. Funnily enough, it will probably require help from this PR.

@galv
Collaborator

galv commented Feb 11, 2025

FYI here is an initial attempt at pin_memory() support during stream capture: #146924

@galv
Collaborator

galv commented Mar 6, 2025

@nmacchioni would you like me to take this PR over? My work in #146924 dovetails nicely with it, so I can make sure that both use cases (timing, and synchronizing with only part of a graph) work. My impression is that you're almost done, and it won't take much to take this over the finish line.

@nmacchioni
Contributor Author

yeah @galv, if you're able to take this PR over please do; I was pulled into a war room over here and I still don't have time to get back to this

Contributor

github-actions bot commented May 5, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label May 5, 2025
Labels
release notes: cuda, Stale

3 participants