Add cross attention type for Sana-Sprint training in diffusers. #11514
Conversation
elif cross_attention_type == "vanilla":
    cross_attention_processor = SanaAttnProcessor3_0()
Can't we modify the SanaAttnProcessor2_0 class to handle the changes in SanaAttnProcessor3_0?
If we merge 2_0 and 3_0, we then need a variable to check when to use the function here:

hidden_states = self.scaled_dot_product_attention(

which will be similar to:

cross_attention_type: str = "flash",
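For concreteness, here is a minimal sketch of what such a merged processor could look like (hypothetical code, not the final implementation; the class and method names just mirror the snippets quoted above):

import torch.nn.functional as F

class SanaAttnProcessorMerged:
    # Hypothetical merged processor: a single flag picks between the fused
    # SDPA kernel and an explicit implementation, mirroring the
    # cross_attention_type argument quoted above.
    def __init__(self, cross_attention_type: str = "flash"):
        self.cross_attention_type = cross_attention_type

    def scaled_dot_product_attention(self, query, key, value):
        if self.cross_attention_type == "flash":
            return F.scaled_dot_product_attention(query, key, value)
        # "vanilla" path: explicit softmax(QK^T / sqrt(d)) V
        scale = query.shape[-1] ** -0.5
        attn_weights = (query @ key.transpose(-2, -1) * scale).softmax(dim=-1)
        return attn_weights @ value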
@@ -360,6 +453,7 @@ def __init__(
        guidance_embeds_scale: float = 0.1,
        qk_norm: Optional[str] = None,
        timestep_scale: float = 1.0,
+       cross_attention_type: str = "flash",
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This goes a bit against our design.
Then can we just separate it into two classes, and could you help with a better implementation?
Actually, the only difference is that F.scaled_dot_product_attention is not supported by forward-mode autodiff (torch.func.jvp). Therefore, during training we need to replace it with the vanilla attention implementation. Any good idea how to merge these two? @sayakpaul
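To illustrate the limitation, here is a minimal sketch (assuming a recent PyTorch with torch.func; shapes are arbitrary):

import torch

def vanilla_attention(query, key, value):
    # Explicit softmax(QK^T / sqrt(d)) V: every op here has a
    # forward-mode (JVP) rule, unlike the fused SDPA kernels.
    scale = query.shape[-1] ** -0.5
    attn_weights = (query @ key.transpose(-2, -1) * scale).softmax(dim=-1)
    return attn_weights @ value

q = k = v = torch.randn(1, 8, 16, 64)
tq = tk = tv = torch.randn_like(q)

# Forward-mode AD works with the explicit implementation:
out, jvp_out = torch.func.jvp(vanilla_attention, (q, k, v), (tq, tk, tv))

# The equivalent call through F.scaled_dot_product_attention typically
# fails, because the fused kernels do not implement a forward-mode rule.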
Ah I see. If that is the case, I think we should go through the attention processor mechanism, wherein we use something like set_attn_processor and use the vanilla attention processor class.
If this is only needed for training, I think we should have the following methods added to the model class:
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def attn_processors(self) -> Dict[str, AttentionProcessor]:
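For reference, a sketch of the recursive pattern these helpers use on other diffusers models (e.g. UNet2DConditionModel); the Sana version would look much the same, though details may differ:

from typing import Dict, Union

from diffusers.models.attention_processor import AttentionProcessor

# To be added to the model class; sketch of the usual diffusers pattern.
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    # Collect every attention processor in the module tree, keyed by name.
    processors = {}

    def fn_recursive_add_processors(name, module, processors):
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor()
        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)
    return processors

def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    # Accept either a single processor (applied everywhere) or a dict
    # keyed like attn_processors.
    count = len(self.attn_processors.keys())
    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(f"Expected {count} processors, got {len(processor)}.")

    def fn_recursive_attn_processor(name, module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))
        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)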
We can then just include the vanilla attention processor implementation in the training utility and do something like
model = SanaTransformer2DModel(...)
model.set_attn_processor(SanaVanillaAttnProcessor())
WDYT? @DN6 any suggestion?
Oh, this is cool and neat IMO, thanks. I'll change the code.
@@ -0,0 +1,1656 @@
+#!/usr/bin/env python
This is perfect! This is 100 percent the way to go here. We can include the attention processor here in a file (attention_processor.py) and use it from there in the training script.
Based on https://github.com/huggingface/diffusers/pull/11514/files#r2077921763.
Since we're using a folder for the training script, I won't mind if we want to move the dataloader out into a separate script and the utilities into another. But completely up to you.
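If we do split things out, the training script would just import from its sibling files, e.g. (SanaSprintDataset is a hypothetical name for the split-out dataloader):

# In train_sana_sprint_diffusers.py, assuming the proposed layout:
from attention_processor import SanaVanillaAttnProcessor  # local attention_processor.py
from dataset import SanaSprintDataset  # hypothetical dataloader module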
I don't mind it. Could you help with this one? :)
Yes, after the https://github.com/huggingface/diffusers/pull/11514/files#r2077921763 comments are addressed, I will help with that.
…SanaAttnProcessor3_0` to `SanaVanillaAttnProcessor`
I have changed the code as recommended here: https://github.com/huggingface/diffusers/pull/11514/files#r2077921763. I hope this is what you meant. @sayakpaul
Tested locally after adding the SanaVanillaAttnProcessor imports; the changes work as expected. LGTM! @lawrence-cj @sayakpaul
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers

python train_sana_sprint_diffusers.py \
This is perfect!
Do we want to move the dataset class into a separate file, dataset.py? I am okay if we want to do that since it's already under research_projects.
Also, let's add a readme with instructions on how to acquire the dataset, etc. Currently, we're only using three shards I think.
I agree with you. Please help with this separated script!
@scxue Please help with the readme part!
README updated! @sayakpaul – Feel free to share any feedback or suggestions.
Looking very nice. Some minor comments and we should be able to merge soon.
Let's go!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.