Add VidTok AutoEncoders #11261
Conversation
Thank you for the PR @annitang1997! I will review this in depth soon. cc @yiyixuxu too
Are there any updates on the review process? 👀 Looking forward to using VidTok with diffusers.
Thank you for the PR, and congratulations on the release of your awesome work!
I did a first-pass review of the changes needed to make the implementation consistent with the rest of the diffusers codebase. Some core implementation details will have to be refactored before we can merge. Good reference implementations for autoencoders can be found here:
- https://github.com/huggingface/diffusers/blob/0dec414d5bf2c7fe77684722b0a97324798bd7b3/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
- https://github.com/huggingface/diffusers/blob/0dec414d5bf2c7fe77684722b0a97324798bd7b3/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
I'd be happy to help make some of these changes! 🤗
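For orientation, new autoencoders in diffusers generally follow a single-file layout like the sketch below. This is a minimal illustration modeled on the linked LTX/Mochi files, not the PR's actual code; `VidTokEncoder3D` and `VidTokDecoder3D` are hypothetical placeholders standing in for the real blocks.

```python
import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class VidTokEncoder3D(nn.Module):
    # Placeholder: the real encoder blocks would live in this same file.
    def __init__(self, in_channels: int, latent_channels: int):
        super().__init__()
        self.proj = nn.Conv3d(in_channels, latent_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class VidTokDecoder3D(nn.Module):
    # Placeholder: mirrors the encoder back to pixel space.
    def __init__(self, latent_channels: int, out_channels: int):
        super().__init__()
        self.proj = nn.Conv3d(latent_channels, out_channels, kernel_size=1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.proj(z)


class AutoencoderVidTok(ModelMixin, ConfigMixin):
    # Arguments are recorded via @register_to_config so the model round-trips
    # through save_pretrained/from_pretrained.
    @register_to_config
    def __init__(self, in_channels: int = 3, latent_channels: int = 4):
        super().__init__()
        self.encoder = VidTokEncoder3D(in_channels, latent_channels)
        self.decoder = VidTokDecoder3D(latent_channels, in_channels)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)
```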
@@ -688,6 +689,158 @@ def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...])
        return z_q


class FSQRegularizer(nn.Module):
We're moving towards maintaining a single file per modeling implementation, so let's move this to the vidtok autoencoder file.
@@ -285,6 +285,27 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.conv2d(inputs, weight, stride=2)


class VidTokDownsample2D(nn.Module):
Let's move this to the vidtok autoencoder file as well.
@@ -470,6 +471,28 @@ def forward(
        return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]


class VidTokLayerNorm(nn.Module):
Let's move this to the vidtok autoencoder file as well.
@@ -356,6 +356,26 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


class VidTokUpsample2D(nn.Module):
Let's move this to the vidtok autoencoder file as well.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, unpack
We need to replace all einops operations with permute/reshape/other ops, since einops adds another dependency that we don't use in the codebase.
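For example, the `"b c t h w -> (b t) c h w"` rearrange used for spatial downsampling in this PR can be written with `permute`/`reshape` alone; a minimal sketch with a toy shape:

```python
import torch

x = torch.randn(2, 8, 4, 16, 16)  # (b, c, t, h, w)

# einops: rearrange(x, "b c t h w -> (b t) c h w")
b, c, t, h, w = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

# einops: rearrange(y, "(b t) c h w -> b c t h w", b=b)
x_back = y.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)

assert torch.equal(x, x_back)
```

`pack`/`unpack` calls can likewise be replaced with `reshape` plus the shapes saved before flattening.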
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module.downsample(*inputs)

    return custom_forward
Suggested change:
-def create_custom_forward(module):
-    def custom_forward(*inputs):
-        return module.downsample(*inputs)
-    return custom_forward
if i_level in self.spatial_ds:
    # spatial downsample
    htmp = rearrange(hs[-1], "b c t h w -> (b t) c h w")
    htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)
Suggested change:
-    htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)
+    htmp = self._gradient_checkpointing_func(self.down[i_level], htmp)
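For context, `_gradient_checkpointing_func` is set up by `ModelMixin` when gradient checkpointing is enabled, and modeling code in diffusers usually gates it on the module's `gradient_checkpointing` flag. A sketch of the usual pattern (the surrounding control flow is illustrative, not the exact PR code):

```python
# Inside forward(): checkpoint the block only when training with
# gradient checkpointing enabled; otherwise call it directly.
if torch.is_grad_enabled() and self.gradient_checkpointing:
    htmp = self._gradient_checkpointing_func(self.down[i_level], htmp)
else:
    htmp = self.down[i_level](htmp)
```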
B, _, T, H, W = htmp.shape
# middle
h = hs[-1]
h = torch.utils.checkpoint.checkpoint(self.mid.block_1, h, temb)
Same comment as above for these usages.
        return h


class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Suggested change:
-class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderVidTok(ModelMixin, ConfigMixin):
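Dropping `FromOriginalModelMixin` means the class no longer exposes `from_single_file`; the model would then load through the standard diffusers-format path only. A hypothetical usage (the repo id is taken from the PR description; the actual checkpoint layout may differ):

```python
import torch

from diffusers import AutoencoderVidTok  # exposed once this PR is merged

# Hypothetical checkpoint location, for illustration only.
vae = AutoencoderVidTok.from_pretrained("microsoft/VidTok", torch_dtype=torch.float16)
```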
        self.tile_overlap_factor_width = 0.0  # 1 / 8

    @staticmethod
    def pad_at_dim(
Any methods that are not meant to be invoked directly by users should be made private, i.e. prefixed with an underscore (`_pad_at_dim`).
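As a concrete sketch, the renamed helper could look like this (the body follows a common `pad_at_dim` implementation and is assumed, not copied from the PR):

```python
from typing import Tuple

import torch
import torch.nn.functional as F


class AutoencoderVidTok:  # class body elided; only the renamed helper is shown
    @staticmethod
    def _pad_at_dim(t: torch.Tensor, pad: Tuple[int, int], dim: int = -1, value: float = 0.0) -> torch.Tensor:
        # F.pad orders padding pairs from the last dimension backwards, so
        # prepend zero-pairs until the target dimension is reached.
        dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = (0, 0) * dims_from_right
        return F.pad(t, (*zeros, *pad), value=value)
```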
We add VidTok, a versatile and state-of-the-art video tokenizer, as an autoencoder model to diffusers.
Paper: https://arxiv.org/pdf/2412.13061
Code: https://github.com/microsoft/VidTok
Model: https://huggingface.co/microsoft/VidTok