
Conversation

@patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Aug 7, 2023

What does this PR do?

This PR is similar in spirit to #4114 .

Every pipeline can run enable_model_cpu_offload, so this is a method we can move to PipelineModelMixin to remove some of the boilerplate code here.

Since every pipeline has a slightly different chain in which models should be on- and offloaded, we need to add a class attribute that defines this chain as a string.

This PR also adds a free_hooks method that should be called at the end of every pipeline's __call__ function. This method should be more robust than what we currently have and should also fix bugs such as the following: #2907
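Roughly, the idea looks like this (a minimal sketch of the mechanism, not the final implementation; the class name and component order here are illustrative):

import torch
from accelerate import cpu_offload_with_hook

class ExamplePipeline:
    # hypothetical offload order; each real pipeline declares its own modules
    model_cpu_offload_seq = "text_encoder->unet->vae"

    def enable_model_cpu_offload(self, gpu_id=0):
        device = torch.device(f"cuda:{gpu_id}")
        hook = None
        self._all_hooks = []
        for name in self.model_cpu_offload_seq.split("->"):
            # each hook onloads its model when forward() runs and offloads the
            # previous model, so only one model occupies the GPU at a time
            _, hook = cpu_offload_with_hook(getattr(self, name), device, prev_module_hook=hook)
            self._all_hooks.append(hook)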

TODO:

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 7, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Contributor Author

In this PR we should also be able to solve the following issue nicely: #4435 (comment)

This works simply because we will offload all components in the maybe_free_model_hooks call.
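Continuing the sketch above (and assuming the _all_hooks list built in enable_model_cpu_offload), the end-of-call cleanup could look roughly like:

    def maybe_free_model_hooks(self):
        # move every component back to the CPU so memory is reclaimed even if
        # a call ended partway through the model chain (cf. #2907)
        if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
            return  # model offloading was never enabled, nothing to do
        for hook in self._all_hooks:
            hook.offload()

    def __call__(self, *args, **kwargs):
        ...  # denoising loop, decoding, etc.
        self.maybe_free_model_hooks()  # always the last step of a pipeline call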

@Kubuxu feel free to also give this PR a review

@patrickvonplaten
Contributor Author

@DN6 could you maybe try to take over this PR?

@patrickvonplaten
Contributor Author

Any progress here @DN6 ?

@DN6
Collaborator

DN6 commented Sep 4, 2023

@patrickvonplaten Handling it this week


@Kubuxu Kubuxu left a comment


From the perspective of #4435, this solves it nicely.

@DN6
Collaborator

DN6 commented Sep 11, 2023

@patrickvonplaten This is ready for another review.

@patrickvonplaten
Contributor Author

patrickvonplaten commented Sep 11, 2023

Looks good to me! I think once the merge conflicts are resolved and we have verified that all of the tests matched by:

pytest tests/pipelines -k "test_model_cpu_offload_forward_pass"

pass on GPU, we can merge this.
Note that these tests don't run on PRs since they require a GPU. Did you check that they all pass?

Also I think we should slightly change the offloading method in the end: https://github.com/huggingface/diffusers/pull/4514/files#r1321307022 (wdyt?)

@DN6
Collaborator

DN6 commented Sep 11, 2023

@patrickvonplaten Getting two failures at the moment when testing, both from Shap E. enable_full_determinism isn't set on those tests.

================================================ short test summary info =================================================
FAILED tests/pipelines/shap_e/test_shap_e.py::ShapEPipelineFastTests::test_model_cpu_offload_forward_pass - RuntimeError: cumsum_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_a...
FAILED tests/pipelines/shap_e/test_shap_e_img2img.py::ShapEImg2ImgPipelineFastTests::test_model_cpu_offload_forward_pass - RuntimeError: cumsum_cuda_kernel does not have a deterministic implementation, but you set 'torch.use_deterministic_a...
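For reference, the failure comes from torch.use_deterministic_algorithms: cumsum has no deterministic CUDA kernel, so any test that enables full determinism and then hits a cumsum on GPU will raise. One possible workaround, assuming the tests can tolerate a warning instead of a hard error, is PyTorch's warn_only flag:

import torch

# warn_only=True downgrades the "does not have a deterministic
# implementation" error for ops like cumsum_cuda_kernel to a warning
torch.use_deterministic_algorithms(True, warn_only=True)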

@patrickvonplaten
Contributor Author

Ok, let's merge regardless of the failures and solve them afterward. Can you fix the merge conflicts so that we can merge?

@DN6
Collaborator

DN6 commented Sep 11, 2023

@patrickvonplaten Merge conflicts are resolved and your suggestions have been added. There's a failing doc test, but I'm not able to reproduce it locally. Any idea what the issue might be?

"""

_load_connected_pipes = True
model_cpu_offload_seq = "text_encoder->unet->movq->prior_prior->prior_image_encoder->prior_text_encoder"
Contributor Author


Nice!

latents = latents * self.scheduler.init_noise_sigma
return latents

def enable_model_cpu_offload(self, gpu_id=0):
Contributor Author


Ok for now, but why not use the default way of model offloading here?

Collaborator


Something I noticed while working on this.

Certain pipelines (AudioLDM2, MusicLDM, Shap E) do not make use of the forward method of their components. Instead, they pass inputs directly into submodules of the component.

prompt_embeds = self.text_encoder.get_text_features(

This leads to a device mismatch error since accelerate only moves the module back to GPU when forward is called.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu...

IMO, if pipelines are using submodules of their components during inference, it's fine for them to implement their own enable_model_cpu_offload, since it can be challenging for us to know exactly which modules to offload.
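To illustrate with a toy, hypothetical example (not the diffusers code): accelerate attaches its onload logic to forward, so any other entry point runs while the weights are still on the CPU:

import torch
from accelerate import cpu_offload_with_hook

class ToyEncoder(torch.nn.Module):
    # toy stand-in for a component such as a text encoder
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.proj(x)

    def get_text_features(self, x):
        # calls a submodule directly, just like
        # self.text_encoder.get_text_features(...) in the pipeline
        return self.proj(x)

encoder, hook = cpu_offload_with_hook(ToyEncoder(), torch.device("cuda:0"))
x = torch.randn(1, 4, device="cuda:0")
try:
    encoder.get_text_features(x)  # bypasses forward(), so the hook never fires
except RuntimeError as err:
    print(err)  # Expected all tensors to be on the same device ...
encoder(x)  # works: the hook moves the weights to cuda:0 inside forward()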

Collaborator


In this PR, I've cleaned up enable_model_cpu_offload in the problematic pipelines to properly offload the submodules, so that users still get the expected memory savings. Alternatively, we could move these problematic modules into the _exclude_from_cpu_offload list and use the enable_model_cpu_offload defined in DiffusionPipeline, but that would reduce the memory savings.
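For context, the alternative would look roughly like this (the module names are illustrative, not the real AudioLDM2 configuration):

from diffusers import DiffusionPipeline

class ExampleAudioPipeline(DiffusionPipeline):
    # modules listed here are skipped by the generic offload logic, trading
    # some memory savings for correct direct calls into their submodules
    _exclude_from_cpu_offload = ["text_encoder"]
    model_cpu_offload_seq = "unet->vae"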

@patrickvonplaten patrickvonplaten merged commit 9357965 into main Sep 11, 2023
@patrickvonplaten patrickvonplaten deleted the refactor_model_offload branch September 11, 2023 17:39
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* [Draft] Refactor model offload

* [Draft] Refactor model offload

* Apply suggestions from code review

* cpu offload updates

* remove model cpu offload from individual pipelines

* add hook to offload models to cpu

* clean up

* model offload

* add model cpu offload string

* make style

* clean up

* fixes for offload issues

* fix tests issues

* resolve merge conflicts

* update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Patrick von Platen <[email protected]>

* make style

* Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

---------

Co-authored-by: Dhruv Nair <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024