This repository contains the official implementation of the paper "Learning Plug-and-play Memory for Guiding Video Diffusion Models".
- [2025-11-24] Code and paper released. We also release our training data, memory data, and the trained DiT-Mem-1.3B weights.
- Introduction
- Method
- Installation
- Data Preparation
- Model Weights
- Training
- Inference
- Evaluation
- Citation
- Acknowledgements
DiT-Mem is a plug-and-play memory module for DiT-based video diffusion models designed to inject rich world knowledge during generation. Instead of scaling model size or data, DiT-Mem retrieves a few relevant reference videos and encodes them into compact memory tokens using 3D CNNs, frequency-domain filtering (HPF/LPF), and lightweight attention. These tokens are then inserted into the DiT’s self-attention layers to guide generation without modifying the backbone.
Our method requires finetuning only a small memory encoder on 10K videos while keeping the diffusion model frozen. When applied to Wan2.1 and Wan2.2, DiT-Mem improves controllability, semantic reasoning, and physics consistency—often surpassing strong commercial systems—while remaining efficient and fully modular.
Our framework consists of three main steps:
- Retrieval: Given a text prompt, we retrieve relevant reference videos from an external memory bank.
- Memory Encoding: A lightweight encoder processes these videos using 3D CNNs for downsampling, frequency-domain filters (Low-Pass/High-Pass) for feature disentanglement, and self-attention for aggregation.
- Injection: The resulting memory tokens are concatenated with the hidden states of the frozen DiT backbone during inference, providing guidance without altering the original model weights.
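For intuition, here is a minimal PyTorch sketch of steps 2 and 3 written from the description above. All module names, dimensions, and the frequency cutoff are illustrative assumptions, not the released implementation:

```python
# Illustrative sketch of the memory encoder; not the released implementation.
import torch
import torch.nn as nn


def frequency_split(x: torch.Tensor, cutoff: float = 0.25):
    """Split features into low-/high-frequency parts with a 3D FFT mask (cutoff is assumed)."""
    freq = torch.fft.fftshift(torch.fft.fftn(x, dim=(-3, -2, -1)), dim=(-3, -2, -1))
    t, h, w = x.shape[-3:]
    grids = torch.meshgrid(
        torch.linspace(-1, 1, t), torch.linspace(-1, 1, h), torch.linspace(-1, 1, w),
        indexing="ij",
    )
    radius = torch.sqrt(sum(g**2 for g in grids)).to(x.device)
    low_mask = (radius <= cutoff).to(freq.dtype)

    def inverse(f):
        return torch.fft.ifftn(torch.fft.ifftshift(f, dim=(-3, -2, -1)), dim=(-3, -2, -1)).real

    return inverse(freq * low_mask), inverse(freq * (1 - low_mask))


class MemoryEncoder(nn.Module):
    """Encode retrieved reference-video latents into compact memory tokens (sketch)."""

    def __init__(self, in_channels: int = 16, dim: int = 1536, num_heads: int = 8):
        super().__init__()
        self.down = nn.Conv3d(in_channels, dim, kernel_size=3, stride=2, padding=1)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, ref_latents: torch.Tensor) -> torch.Tensor:
        feats = self.down(ref_latents)              # (B, C, T, H, W) latents -> 3D CNN downsampling
        low, high = frequency_split(feats)          # LPF / HPF feature disentanglement
        tokens = torch.cat([low, high], dim=2)      # stack the two branches along time
        tokens = tokens.flatten(2).transpose(1, 2)  # (B, N, dim) token sequence
        out, _ = self.attn(tokens, tokens, tokens)  # self-attention aggregation
        return self.norm(out + tokens)              # compact memory tokens


memory_tokens = MemoryEncoder()(torch.randn(1, 16, 8, 16, 16))
# Step 3: at generation time these tokens are concatenated with the hidden states inside the
# frozen DiT's self-attention layers, so the backbone weights are never modified.
```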
- Clone the repository:

  ```bash
  git clone https://github.com/Thrcle421/DiT-Mem.git
  cd DiT-Mem
  ```

- Create a conda environment and install dependencies:

  ```bash
  conda create -n dit_mem python=3.10
  conda activate dit_mem
  pip install -r requirements.txt
  ```
Our training and memory data are derived from OpenVid-1M, specifically the OpenVidHD-0.4M subset.
- Training Data: We randomly sampled 10k videos from OpenVidHD-0.4M, drawing from each data part in proportion to its size.
- Memory Data: We used the remaining videos from OpenVidHD-0.4M, excluding 100 videos reserved for benchmark testing.
- CSV Files: Please download the corresponding CSV files from HuggingFace Dataset and place them in the `data/` directory.
- Video Data: Download the full OpenVidHD-0.4M video dataset and place it in the `video/` directory.
```
# Example structure
data/
├── train.csv    # 10k training samples
└── memory.csv   # Memory bank videos
video/
└── ...          # Video files
```

To build the retrieval index, follow these steps:
- Download Model: Download the Alibaba-NLP/gte-base-en-v1.5 model and place it in `model/gte-base-en-v1.5`.
- Build Index: Run the following command to generate `labels.index` and `id_map.json`:

  ```bash
  python memory_index/build_retrieve_index.py
  ```

  This will create:

  - `memory_index/labels.index`: FAISS index for retrieval.
  - `memory_index/id_map.json`: Mapping from IDs to video paths.
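For reference, below is a minimal sketch of how a prompt could be matched against this index at inference time. It assumes the index stores normalized gte-base-en-v1.5 embeddings and that `id_map.json` maps stringified integer IDs to video paths; check `memory_index/build_retrieve_index.py` for the exact layout:

```python
# Illustrative retrieval helper; the repo's own retrieval code may differ in details.
import json

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

index = faiss.read_index("memory_index/labels.index")   # FAISS index built above
with open("memory_index/id_map.json") as f:
    id_map = json.load(f)                                # assumed: {"0": "video/xxx.mp4", ...}

encoder = SentenceTransformer("model/gte-base-en-v1.5", trust_remote_code=True)


def retrieve(prompt: str, k: int = 5) -> list[str]:
    """Return the paths of the k memory-bank videos whose captions best match the prompt."""
    query = encoder.encode([prompt], normalize_embeddings=True)
    _, ids = index.search(np.asarray(query, dtype="float32"), k)
    return [id_map[str(i)] for i in ids[0]]


print(retrieve("A glass of water tips over on a wooden table", k=5))
```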
To accelerate training and inference, we pre-compute VAE latents for all videos in the memory bank.
- Run Pre-computation:

  ```bash
  bash latent_processing/vae_latent_processing.sh
  ```

  Ensure that `CSV_FILE` in the script points to your memory data CSV (e.g., `data/memory.csv`). The encoded latents will be saved in the `latent/` directory.
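For orientation, a sketch of what such a pre-computation loop can look like is shown below. It assumes a diffusers-style Wan VAE (`AutoencoderKLWan` from the `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` checkpoint) and a CSV with a `video_path` column; the shipped script is built on DiffSynth-Studio and may use a different loader and file layout:

```python
# Illustrative pre-computation loop; see latent_processing/ for the actual script.
import os

import pandas as pd
import torch
from diffusers import AutoencoderKLWan
from torchvision.io import read_video

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
).to(device)

df = pd.read_csv("data/memory.csv")                      # assumed to contain a `video_path` column
os.makedirs("latent", exist_ok=True)

for path in df["video_path"]:
    frames, _, _ = read_video(path, output_format="TCHW", pts_unit="sec")
    frames = frames[: (len(frames) - 1) // 4 * 4 + 1]    # Wan's causal VAE expects 4k+1 frames
    video = frames.float().div(127.5).sub(1.0)           # uint8 [0, 255] -> float [-1, 1]
    video = video.permute(1, 0, 2, 3).unsqueeze(0)       # (1, C, T, H, W)
    with torch.no_grad():
        latent = vae.encode(video.to(device)).latent_dist.mode()
    torch.save(latent.cpu(), os.path.join("latent", os.path.basename(path) + ".pt"))
```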
- Base Model: Download the Wan2.1-T2V-1.3B model and place it in `model/Wan2.1-T2V-1.3B`.
- DiT-Mem Checkpoint: Download our trained checkpoint from HuggingFace Model and place it at `checkpoint/DiT-Mem-1.3B.safetensors`.
Structure:
```
DiT-Mem/
├── checkpoint/
│   └── DiT-Mem-1.3B.safetensors
├── model/
│   ├── Wan2.1-T2V-1.3B/
│   │   ├── config.json
│   │   └── ...
│   └── gte-base-en-v1.5/
│       ├── config.json
│       └── ...
```
To train the memory encoder, use the training script:
```bash
bash scripts/train_dit_mem.sh
```

The training config is located at `config/train_dit_mem.yaml`.
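As a mental model of what this script does, here is a minimal PyTorch schematic of the setup described above: only the memory encoder receives gradients while the DiT backbone stays frozen. Every module, tensor shape, and the loss below are placeholders rather than the repository's actual classes:

```python
# Schematic only: a toy "backbone" and encoder stand in for Wan2.1 and the real memory encoder.
import torch
import torch.nn as nn


class TinyMemoryEncoder(nn.Module):
    """Toy stand-in for the memory encoder (the only trainable component)."""

    def __init__(self, in_channels: int = 16, dim: int = 1536):
        super().__init__()
        self.proj = nn.Conv3d(in_channels, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, ref_latents: torch.Tensor) -> torch.Tensor:
        return self.proj(ref_latents).flatten(2).transpose(1, 2)   # (B, N, dim) memory tokens


dit_backbone = nn.Linear(1536, 1536)           # placeholder for the frozen Wan DiT
dit_backbone.requires_grad_(False)             # backbone weights are never updated

encoder = TinyMemoryEncoder()
optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-4)   # optimize the encoder only

# One illustrative step with random tensors in place of real video latents and hidden states.
ref_latents = torch.randn(1, 16, 8, 16, 16)
hidden_states = torch.randn(1, 128, 1536)
memory_tokens = encoder(ref_latents)
guided = dit_backbone(torch.cat([hidden_states, memory_tokens], dim=1))   # inject memory tokens
loss = guided.pow(2).mean()                    # placeholder for the actual diffusion training loss
loss.backward()
optimizer.step()
```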
To generate videos using DiT-Mem, run the provided script:
```bash
bash inference/generate_videos.sh
```

Parameters in `inference/generate_videos.sh`:

- `CHECKPOINT_PATH`: Path to the DiT-Mem checkpoint (prepared in Model Weights).
- `BASE_MODEL`: Path to the frozen base model (prepared in Model Weights).
- `CSV_FILE`: Input CSV containing prompts.
- `RETRIEVAL_K`: Number of reference videos to retrieve (default: 5).
- `NUM_INFERENCE_STEPS`: Number of denoising steps (default: 40).
We provide scripts to evaluate DiT-Mem on two public benchmarks:
- VBench
  - Script: `evaluation/vbench/run_vbench_evaluation.sh`
  - Official project page: VBench project page
- PhyGenBench
  - Script: `evaluation/phygenbench/run_phygenbench_evaluation.sh`
  - Official project page: PhyGenBench project page
For detailed instructions, please refer to evaluation/README.md.
If you find this work useful, please cite our paper:
```bibtex
@article{song2025learning,
title={Learning Plug-and-play Memory for Guiding Video Diffusion Models},
author={Song, Selena and Xu, Ziming and Zhang, Zijun and Zhou, Kun and Guo, Jiaxian and Qin, Lianhui and Huang, Biwei},
journal={arXiv preprint arXiv:2511.19229},
year={2025}
}
```

This codebase is built upon DiffSynth-Studio and Wan2.1. We also acknowledge the use of OpenVid-1M for training data and gte-base-en-v1.5 as the retrieval model. We thank VBench and PhyGenBench for their evaluation benchmarks, and we thank all of these authors for their open-source contributions.