
Add VidTok AutoEncoders #11261


Open
wants to merge 4 commits into main

Conversation

annitang1997

We add VidTok, a versatile and state-of-the-art video tokenizer, as an autoencoder model to diffusers.

Paper: https://arxiv.org/pdf/2412.13061
Code: https://github.com/microsoft/VidTok
Model: https://huggingface.co/microsoft/VidTok
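
A rough usage sketch of what this could look like once merged (the API is still under review, so the AutoencoderVidTok import path, checkpoint layout, and encode/decode return types below are all assumptions):

```python
import torch
from diffusers import AutoencoderVidTok  # assumed import path, pending review

# Checkpoint id taken from the links above; exact file layout is an assumption.
vae = AutoencoderVidTok.from_pretrained("microsoft/VidTok").to("cuda")

# Dummy clip: (batch, channels, frames, height, width)
video = torch.randn(1, 3, 17, 256, 256, device="cuda")

with torch.no_grad():
    latents = vae.encode(video)  # exact return type depends on the final API
    recon = vae.decode(latents)
```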

@a-r-r-o-w (Member)

Thank you for the PR @annitang1997! I will review this in depth soon. cc @yiyixuxu too

@deeptimhe commented Apr 20, 2025

Are there any updates on the review process? 👀 Looking forward to using VidTok with diffusers.

@a-r-r-o-w (Member) left a comment

Thank you for the PR, and congratulations on the release of your awesome work!

I did a first-pass review of the changes needed to bring the implementation in line with the rest of the diffusers codebase. Some core implementation details will have to be refactored before we can merge. A good reference implementation for autoencoders can be found here:

I'd be happy to help with making some of these changes! 🤗

@@ -688,6 +689,158 @@ def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...])
return z_q


class FSQRegularizer(nn.Module):

We're moving towards maintaining a single file per modeling implementation, and so let's move this to the vidtok autoencoder file

@@ -285,6 +285,27 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return F.conv2d(inputs, weight, stride=2)


class VidTokDownsample2D(nn.Module):

Let's move this to the vidtok autoencoder file as well.

@@ -470,6 +471,28 @@ def forward(
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]


class VidTokLayerNorm(nn.Module):

Let's move this to the vidtok autoencoder file as well.

@@ -356,6 +356,26 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)


class VidTokUpsample2D(nn.Module):

Let's move this to the vidtok autoencoder file as well.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, unpack

We need to replace all einops operations with permute/reshape/other torch ops, since einops adds another dependency that we don't use in the codebase.
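
For example, the rearrange(x, "b c t h w -> (b t) c h w") pattern used below has a direct torch equivalent (a minimal sketch with arbitrary tensor sizes):

```python
import torch

x = torch.randn(2, 8, 4, 16, 16)  # (b, c, t, h, w)

# einops: rearrange(x, "b c t h w -> (b t) c h w")
b, c, t, h, w = x.shape
y = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

# inverse, einops: rearrange(y, "(b t) c h w -> b c t h w", b=b)
x_back = y.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)
assert torch.equal(x, x_back)
```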

Comment on lines +604 to +610

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module.downsample(*inputs)

    return custom_forward


Suggested change
-    def create_custom_forward(module):
-        def custom_forward(*inputs):
-            return module.downsample(*inputs)
-        return custom_forward

if i_level in self.spatial_ds:
    # spatial downsample
    htmp = rearrange(hs[-1], "b c t h w -> (b t) c h w")
    htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)

Suggested change
-    htmp = torch.utils.checkpoint.checkpoint(create_custom_forward(self.down[i_level]), htmp)
+    htmp = self._gradient_checkpointing_func(self.down[i_level], htmp)
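
For context, diffusers models get _gradient_checkpointing_func from ModelMixin, which wraps torch.utils.checkpoint internally, so the custom_forward closure above becomes unnecessary. The usual pattern looks roughly like this (a sketch, not the exact VidTok code):

```python
# Inside forward(); the `hidden_states` naming is illustrative.
if torch.is_grad_enabled() and self.gradient_checkpointing:
    hidden_states = self._gradient_checkpointing_func(self.down[i_level], hidden_states)
else:
    hidden_states = self.down[i_level](hidden_states)
```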

B, _, T, H, W = htmp.shape
# middle
h = hs[-1]
h = torch.utils.checkpoint.checkpoint(self.mid.block_1, h, temb)

same comment as above for these usages

return h


class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):

Suggested change
-class AutoencoderVidTok(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderVidTok(ModelMixin, ConfigMixin):

self.tile_overlap_factor_width = 0.0 # 1 / 8

@staticmethod
def pad_at_dim(

Any methods that are not meant to be invoked directly by users should be made private, i.e. prefixed with an underscore (_pad_at_dim).
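
For reference, the privatized helper could look like the sketch below. This assumes pad_at_dim follows the common pattern of padding a single dimension via F.pad; the signature is illustrative, not taken from this PR:

```python
import torch
import torch.nn.functional as F

class AutoencoderVidTok:  # stub class for illustration only
    @staticmethod
    def _pad_at_dim(t: torch.Tensor, pad: tuple, dim: int = -1, value: float = 0.0) -> torch.Tensor:
        # F.pad takes (left, right) pairs starting from the last dimension,
        # so prepend zero pairs for every dimension to the right of `dim`.
        dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
        zeros = (0, 0) * dims_from_right
        return F.pad(t, (*zeros, *pad), value=value)

# e.g. pad 2 frames on each side of the temporal dim of a (b, c, t, h, w) tensor:
x = torch.randn(1, 3, 8, 16, 16)
y = AutoencoderVidTok._pad_at_dim(x, (2, 2), dim=2)
assert y.shape == (1, 3, 12, 16, 16)
```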
