😈 This repository offers an unofficial PyTorch implementation of the paper Mean Flows for One-step Generative Modeling, building upon Just-a-DiT and EzAudio.
💬 Contributions and feedback are very welcome — feel free to open an issue or pull request if you spot something or have ideas!
🛠️ This codebase is kept as clean and minimal as possible for easier integration into your own projects — thus, frameworks like Wandb are intentionally excluded.
Sorry, I’ve been busy with other projects lately and haven’t had time to add more features to this repo.
Recently, rcm released a Triton implementation of JVP, which is insane: you can now combine Flash Attention with MeanFlow.
MNIST -- 10k training steps, 1-step sample result:
MNIST -- 6k training steps, 1-step CFG (w=2.0) sample result:
CIFAR-10 -- 200k training steps, 1-step CFG (w=2.0) sample result:
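The 1-step results above come from a single forward pass: with the MeanFlow identity, a model that predicts the average velocity `u(z, r, t)` maps noise directly to data via `x = z - (t - r) * u(z, r, t)` evaluated at `(r, t) = (0, 1)`. A minimal sketch, where `model` and `sample_one_step` are illustrative names rather than this repo's actual API:

```python
import torch

@torch.no_grad()
def sample_one_step(model, shape, device="cpu"):
    # z_1 is pure Gaussian noise at time t = 1.
    z = torch.randn(shape, device=device)
    r = torch.zeros(shape[0], device=device)
    t = torch.ones(shape[0], device=device)
    # MeanFlow identity with (r, t) = (0, 1): x = z_1 - u_theta(z_1, 0, 1).
    return z - model(z, r, t)
```

Multi-step sampling works the same way by splitting `[0, 1]` into intervals and applying the identity per interval.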
- Implement basic training and inference
- Enable multi-GPU training via 🤗 Accelerate
- Add support for Classifier-Free Guidance (CFG)
- Integrate latent image representation support
- Add tricks such as the improved CFG described in the paper's Appendix
- `jvp` is incompatible with Flash Attention, and likely also with Triton, Mamba, and similar libraries.
- `jvp` significantly increases GPU memory usage, even when using `torch.utils.checkpoint`.
- CFG is implemented implicitly, which leads to some limitations:
  - The CFG scale is fixed at training time and cannot be adjusted during inference.
  - Negative prompts, such as the "noise" or "low quality" prompts commonly used in text-to-image diffusion models, are not supported.
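The `jvp` limitations above stem from how the MeanFlow target is computed: the loss needs the total time derivative `du/dt`, which is obtained with forward-mode autodiff in a single pass. A minimal sketch of that computation, assuming a toy stand-in network (`TinyNet` is illustrative, not the repo's DiT):

```python
import torch
from torch.func import jvp

class TinyNet(torch.nn.Module):
    def __init__(self, dim=4):
        super().__init__()
        self.fc = torch.nn.Linear(dim + 2, dim)

    def forward(self, z, r, t):
        # Concatenate the two times onto the state (a real model would embed them).
        rt = torch.stack([r, t], dim=-1)
        return self.fc(torch.cat([z, rt], dim=-1))

net = TinyNet()
B, D = 8, 4
x = torch.randn(B, D)                        # data batch
e = torch.randn(B, D)                        # Gaussian noise
t = torch.rand(B)                            # flow time t
r = t * torch.rand(B)                        # r sampled with r <= t
z = (1 - t[:, None]) * x + t[:, None] * e    # point on the linear path
v = e - x                                    # instantaneous (conditional) velocity

# jvp returns u and its total time derivative du/dt in one forward pass;
# the tangents (v, 0, 1) encode dz/dt = v, dr/dt = 0, dt/dt = 1.
u, dudt = jvp(net, (z, r, t), (v, torch.zeros(B), torch.ones(B)))

# MeanFlow identity target: u_tgt = v - (t - r) * du/dt, gradients stopped.
u_tgt = (v - (t - r)[:, None] * dudt).detach()
loss = torch.mean((u - u_tgt) ** 2)
```

Because `torch.func.jvp` must trace every operation in the forward pass, custom fused kernels such as Flash Attention (which don't expose forward-mode rules) break it, and the extra tangent computation is what drives up memory.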
If you find this repo helpful or interesting, consider dropping a ⭐ — it really helps and means a lot!