ggml-backend : fix async copy from CPU #8897

Merged · 2 commits into master on Aug 7, 2024

Conversation

@slaren (Member) commented Aug 6, 2024

Fixes #8685

The problem was that some copies from the CPU backend to the CUDA backend were not correctly synchronized, which in some cases allowed the CPU backend to overwrite the data with the next batch before it had been copied to the GPU.
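
For illustration, a minimal CUDA sketch of this kind of race (the function and variable names are hypothetical; this is not the actual ggml-backend code): a pinned host staging buffer is handed to an asynchronous copy and then refilled with the next batch before the copy has finished.

    // Illustrative only: reusing a pinned host buffer before the async H2D copy finishes.
    #include <cuda_runtime.h>
    #include <string.h>

    // Upload one batch through a pinned staging buffer that is shared between batches.
    static void upload_batch(float * staging, const float * batch, size_t n,
                             float * d_dst, cudaStream_t stream, cudaEvent_t copy_done) {
        memcpy(staging, batch, n * sizeof(float));       // CPU writes the batch into the staging buffer
        cudaMemcpyAsync(d_dst, staging, n * sizeof(float),
                        cudaMemcpyHostToDevice, stream); // the copy is only queued here, not finished
        cudaEventRecord(copy_done, stream);              // marks the point at which the copy completes
        // Without waiting on copy_done before the *next* call overwrites `staging`,
        // the GPU may read data from the wrong batch -- the corruption described above.
        // The fix is to wait (e.g. cudaEventSynchronize(copy_done)) before the buffer is reused.
    }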

@slaren (Member, Author) commented Aug 6, 2024

@matteoserva please let me know if this fixes the issue on your system. I already tested this on @JohannesGaessler's machine, so I expect it works there.

@JohannesGaessler (Collaborator) left a comment

Why does the destination backend need to be synchronized in ggml_backend_tensor_copy_async but not in ggml_backend_sched_compute_splits?

@github-actions bot added the ggml label (changes relating to the ggml tensor library for machine learning) on Aug 6, 2024
@slaren force-pushed the sl/fix-cpu-async-copy branch from cf49428 to a5eae7a on August 6, 2024 at 22:17
@slaren (Member, Author) commented Aug 6, 2024

The idea is that the scheduler makes multiple copies of every input and synchronizes access to them with events. Instead of having to synchronize the entire backend, it is enough to synchronize with the event. However, there was a missing ggml_backend_event_synchronize in this case; it should be fixed now.

    if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
        ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
    } else {
        ggml_backend_synchronize(split_backend);
    }

@JohannesGaessler (Collaborator) replied

I think this synchronization call can be optimized out since with a null event the backend has already been synchronized. But if there is no measurable performance difference it may be better to just keep it in to make the code easier to understand.

@slaren (Member, Author) replied

Yes, I left it there for clarity. For backends that don't support events, ggml_backend_synchronize should be a no-op anyway.
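
To make that scheme concrete — a small ring of input copies, each guarded by its own event, so that reusing one slot only waits for that slot instead of draining the whole backend — here is a hedged CUDA-style sketch (names are hypothetical; this is not the scheduler's actual code):

    // Illustrative ring of input copies with one event per slot.
    #include <cuda_runtime.h>
    #include <stddef.h>

    #define N_COPIES 4

    struct input_ring {
        float *     d_buf[N_COPIES];     // one device-side copy per slot
        cudaEvent_t in_use[N_COPIES];    // recorded after the work touching slot i is queued
    };

    // Before overwriting slot i with new input data, wait only on slot i's event;
    // this is the per-copy synchronization used instead of a full backend synchronize.
    static void wait_for_slot(struct input_ring * ring, int i) {
        cudaEventSynchronize(ring->in_use[i]);
    }

    static void queue_slot(struct input_ring * ring, int i,
                           const float * h_src, size_t n, cudaStream_t stream) {
        cudaMemcpyAsync(ring->d_buf[i], h_src, n * sizeof(float),
                        cudaMemcpyHostToDevice, stream);
        cudaEventRecord(ring->in_use[i], stream);  // the event later waited on by wait_for_slot
    }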

@JohannesGaessler (Collaborator) commented
Prior to the latest commit the fix was working on my second machine with 3x P40. I'll review the new changes tomorrow.

@slaren (Member, Author) commented Aug 6, 2024

The changes to ggml_backend_cuda_cpy_tensor_async in the latest commit are not related to this issue, and these cases are never hit in llama.cpp. Nonetheless, I found these issues while looking into this one, so I am fixing them now to avoid problems in the future.

@matteoserva (Contributor) commented
@slaren The patch fixed the issue on my system. Thank you!


    if (backend_src != backend_dst) {

@JohannesGaessler (Collaborator) commented

How is it ensured that there are no race conditions between backend_src and backend_dst for this code branch?

@slaren (Member, Author) replied

What race conditions are you thinking about? It uses an event to synchronize the two streams.
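
For context, a generic CUDA sketch of that pattern — the destination stream waits on an event recorded in the source stream, so the copy is ordered after the source's queued work without blocking the host (an illustration only, not the ggml CUDA backend's implementation). When the two backends are the same, everything is queued on a single stream and this ordering comes for free, which is the case the quoted if (backend_src != backend_dst) branch separates out.

    // Illustrative only: event-based ordering between a source and a destination stream.
    #include <cuda_runtime.h>
    #include <stddef.h>

    static void copy_across_streams(void * dst, const void * src, size_t size,
                                    cudaStream_t stream_src, cudaStream_t stream_dst,
                                    cudaEvent_t event) {
        cudaEventRecord(event, stream_src);        // point after which `src` is ready on stream_src
        cudaStreamWaitEvent(stream_dst, event, 0); // stream_dst will not run past this point any earlier
        cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream_dst);
        // work queued on stream_dst after this call is automatically ordered after the copy
    }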

@JohannesGaessler (Collaborator) replied

I think I misinterpreted the code. If my understanding is correct the synchronization happens outside this function.

@slaren (Member, Author) replied

Part of the synchronization is done in this function, but the most complicated parts happen in ggml_backend_sched. Ultimately, the only responsibility of this function is to implement the semantics of the copy_async interface of ggml-backend, as defined in ggml-backend.h:

    // asynchronous copy
    // the copy is performed after all the currently queued operations in backend_src
    // backend_dst will wait for the copy to complete before performing other operations
    // automatic fallback to sync copy if async is not supported
    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
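
A short usage sketch of that contract (a hedged illustration with hypothetical names; the scheduler itself relies on per-copy events rather than a full backend synchronize):

    #include "ggml-backend.h"

    // Hypothetical helper: backend_cpu/backend_gpu are initialized backends and
    // src/dst are tensors of the same shape allocated in their respective buffers.
    static void send_input(ggml_backend_t backend_cpu, ggml_backend_t backend_gpu,
                           struct ggml_tensor * src, struct ggml_tensor * dst) {
        ggml_backend_tensor_copy_async(backend_cpu, backend_gpu, src, dst);
        // per the comment above, operations queued later on backend_gpu are ordered after the copy;
        // before the host overwrites the data behind src it must still wait for the copy, e.g.:
        ggml_backend_synchronize(backend_gpu);
    }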

@slaren merged commit be55695 into master on Aug 7, 2024
54 checks passed
@slaren deleted the sl/fix-cpu-async-copy branch on August 7, 2024 at 11:29
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Aug 7, 2024
* ggml-backend : fix async copy from CPU

* cuda : more reliable async copy, fix stream used when the devices are the same
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024
* ggml-backend : fix async copy from CPU

* cuda : more reliable async copy, fix stream used when the devices are the same
Labels: ggml (changes relating to the ggml tensor library for machine learning)

Successfully merging this pull request may close these issues.

Bug: (CUDA) Corrupted output when offloading to multiple GPUs