
Add cross attention type for Sana-Sprint training in diffusers. #11514


Merged · 12 commits · May 8, 2025

Conversation

@scxue (Contributor) commented May 7, 2025

Add cross attention type for Sana-Sprint training in diffusers. @sayakpaul

Comment on lines 315 to 316
elif cross_attention_type == "vanilla":
cross_attention_processor = SanaAttnProcessor3_0()
Member

Can't we modify the SanaAttnProcessor2_0() class to handle the changes from SanaAttnProcessor3_0?

Contributor

If we merge 2_0 and 3_0, we then need a variable to check when to use the function here:

hidden_states = self.scaled_dot_product_attention(

which would be similar to the cross_attention_type: str = "flash" argument.

@sayakpaul
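
(For illustration, a minimal sketch of the flag such a merged processor would need; the class and flag names below are hypothetical and only show the branching referred to above, not an actual diffusers API.)

import math
import torch
import torch.nn.functional as F

class SanaAttnProcessorMerged:  # hypothetical merged 2_0/3_0 processor, names are placeholders
    def __init__(self, use_flash: bool = True):
        # The extra state the comment refers to: a flag deciding which attention
        # path to take, mirroring cross_attention_type = "flash" / "vanilla".
        self.use_flash = use_flash

    def scaled_dot_product_attention(self, query, key, value):
        if self.use_flash:
            # fused kernel path
            return F.scaled_dot_product_attention(query, key, value)
        # explicit path for cases where the fused kernel cannot be used (e.g. under JVP)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.shape[-1])
        return torch.matmul(scores.softmax(dim=-1), value)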

@@ -360,6 +453,7 @@ def __init__(
guidance_embeds_scale: float = 0.1,
qk_norm: Optional[str] = None,
timestep_scale: float = 1.0,
cross_attention_type: str = "flash",
Member

This goes a bit against our design.

@lawrence-cj (Contributor), May 7, 2025

Then can we just separate it into two classes and let you help with a better implementation?

@lawrence-cj (Contributor), May 7, 2025

Actually, the only difference is that F.scaled_dot_product_attention is not supported under torch's JVP (forward-mode autodiff). Therefore, during training we need to replace it with the vanilla attention implementation. Any good idea how to merge these two? @sayakpaul
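
(A minimal sketch of the kind of vanilla replacement being discussed: an explicit softmax(QK^T / sqrt(d)) V built only from matmul and softmax, which forward-mode autodiff can differentiate through. The function name and signature here are assumptions, not the final implementation.)

import math
import torch

def vanilla_scaled_dot_product_attention(query, key, value, attn_mask=None):
    # Explicit equivalent of F.scaled_dot_product_attention, usable when the
    # fused kernel is unavailable (e.g. when training under torch.func.jvp).
    scale = 1.0 / math.sqrt(query.shape[-1])
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    if attn_mask is not None:
        scores = scores + attn_mask
    attn_weights = scores.softmax(dim=-1)
    return torch.matmul(attn_weights, value)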

Member

Ah, I see. If that is the case, I think we should go through the attention processor mechanism, wherein we use something like set_attn_processor with the vanilla attention processor class.

If this is only needed for training, I think we should have the standard attention-processor methods (such as set_attn_processor) added to the model class.

We can then just include the vanilla attention processor implementation in the training utility and do something like

model = SanaTransformer2DModel(...)
model.set_attn_processor(SanaVanillaAttnProcessor())

WDYT? @DN6 any suggestion?
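
(For reference, a rough sketch of what such a vanilla processor could look like. The __call__ signature and attribute names follow diffusers' usual attention-processor pattern, and Sana's optional q/k norms are omitted; treat all of this as an assumption, not the final implementation.)

import math
import torch
from diffusers.models.attention_processor import Attention

class SanaVanillaAttnProcessor:
    # Sketch: identical in spirit to the SDPA-based processor, except the fused
    # F.scaled_dot_product_attention call is replaced with explicit ops.
    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        batch_size = hidden_states.shape[0]
        head_dim = query.shape[-1] // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # explicit attention instead of the fused kernel
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask
        hidden_states = torch.matmul(scores.softmax(dim=-1), value)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states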

Contributor

Oh, this is cool and neat IMO, thanks. I'll change the code.

@@ -0,0 +1,1656 @@
#!/usr/bin/env python
Member

This is perfect! This is 100 percent the way to go here. We can include the attention processor here in a file (attention_processor.py) and use it from there in the training script.

Based on https://github.com/huggingface/diffusers/pull/11514/files#r2077921763.

Since we're using a folder for the training script, I wouldn't mind moving the dataloader into a separate script and the utilities into another. But that's completely up to you.
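
(A sketch of the layout being proposed. The module names, the dataset class, and the "transformer" subfolder are hypothetical placeholders; only SanaVanillaAttnProcessor, set_attn_processor, and the checkpoint repo id come from this thread.)

# train_sana_sprint_diffusers.py (hypothetical layout)
from diffusers import SanaTransformer2DModel

from attention_processor import SanaVanillaAttnProcessor  # local attention_processor.py, as suggested above
from dataset import Text2ImageDataset  # hypothetical separate dataloader module

transformer = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers",
    subfolder="transformer",  # subfolder name assumed
)
transformer.set_attn_processor(SanaVanillaAttnProcessor())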

Contributor

I don't mind it. Could you help with this one? :)

Member

Yes, after the https://github.com/huggingface/diffusers/pull/11514/files#r2077921763 comments are addressed, I will help with that.

@lawrence-cj (Contributor) commented May 8, 2025

I have changed the code as recommended here: https://github.com/huggingface/diffusers/pull/11514/files#r2077921763. I hope that's what you meant. @sayakpaul
Let @scxue check whether my change is correct.

@scxue (Contributor, Author) commented May 8, 2025

Tested locally after adding the SanaVanillaAttnProcessor imports; the changes work as expected. LGTM! @lawrence-cj @sayakpaul


huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers

python train_sana_sprint_diffusers.py \
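
(Equivalently, the teacher checkpoint can be fetched from Python with huggingface_hub; this is a sketch, and the local directory is a placeholder.)

from huggingface_hub import snapshot_download

# Same download as the huggingface-cli command above; adjust local_dir as needed.
snapshot_download(
    repo_id="Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers",
    local_dir="./SANA_Sprint_1.6B_1024px_teacher_diffusers",
)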
Member

This is perfect!

Do we want to move the dataset class into a separate file dataset.py? I am okay if we want to do that since it's already under research_projects.

Also, let's add a README with instructions on how to acquire the dataset, etc. Currently, we're only using three shards, I think.

Contributor

I agree with you. Please help with the separate script!

@scxue Please help with the README part!

Contributor (Author)

README updated! @sayakpaul, feel free to share any feedback or suggestions.

@sayakpaul (Member) left a comment

Looking very nice. Some minor comments and we should be able to merge soon.

@sayakpaul (Member) left a comment

Let's go!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul merged commit 784db0e into huggingface:main on May 8, 2025
8 of 9 checks passed