Variational Auto-Encoder (VAE) + Gaussian Mixture Model (GMM)

An implementation of a mutual-learning model combining a VAE and a GMM.
The idea of integrating probabilistic models is based on the paper Neuro-SERKET: Development of Integrative Cognitive System through the Composition of Deep Probabilistic Generative Models.
The Symbol Emergence in Robotics tool KIT (SERKET) is a framework that allows integration and partitioning of probabilistic generative models.

The graphical model of VAE+GMM:

The VAE and the GMM share the latent variable x.
x follows a multivariate normal distribution and is estimated by the VAE.
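
For concreteness, the generative story implied by this graphical model can be written with torch.distributions as below. This is a minimal sketch of the usual VAE+GMM factorization; K, latent_dim, and the parameter values are placeholders, not the repo's defaults.

import torch
from torch import distributions as D

K, latent_dim = 10, 12           # illustrative sizes, not the repo's defaults

pi = torch.full((K,), 1.0 / K)   # mixing weights of the GMM
mu = torch.randn(K, latent_dim)  # per-cluster means (estimated by the GMM)
var = torch.ones(K, latent_dim)  # per-cluster variances (estimated by the GMM)

z = D.Categorical(probs=pi).sample()         # sample a cluster assignment
x = D.Normal(mu[z], var[z].sqrt()).sample()  # sample the shared latent variable x
# o = decoder(x)  # the VAE decoder (not shown here) maps x to observation space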

Training proceeds in the following sequence; a sketch of the loop is shown after the list.

  1. The VAE estimates the latent variables x and sends them to the GMM.
  2. The GMM clusters the latent variables x sent from the VAE and sends the mean and variance parameters of the Gaussian components back to the VAE.
  3. Return to step 1.
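
Here is a self-contained sketch of this loop. The VAE step is stubbed out, and scikit-learn's EM-based GaussianMixture stands in for the repo's Gibbs-sampling GMM; in the actual code, vae_module and gmm_module play these roles with their own APIs.

import numpy as np
from sklearn.mixture import GaussianMixture

def train_vae_stub(data, prior_mu, prior_var):
    """Stand-in for VAE training: returns 'latent variables' x for the data."""
    return data  # a real VAE would train under the given prior and encode data

data = np.random.randn(200, 2)  # toy data standing in for encoded observations
prior_mu = prior_var = None     # standard-normal prior on the first pass

for it in range(3):             # a few mutual-learning iterations
    x = train_vae_stub(data, prior_mu, prior_var)  # step 1: VAE -> latents x
    gmm = GaussianMixture(n_components=3, covariance_type="diag").fit(x)
    labels = gmm.predict(x)                        # step 2: cluster the latents
    prior_mu = gmm.means_[labels]                  # per-sample Gaussian means
    prior_var = gmm.covariances_[labels]           # per-sample Gaussian variances
    # step 3: prior_mu / prior_var replace the VAE's prior on the next pass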

What this repo contains:

  • main.py: Main script for training the model.
  • vae_module.py: VAE training module, invoked from main.py.
  • gmm_module.py: GMM training module, invoked from main.py.
  • tool.py: Utility functions used by the program.

How to run

Install the required libraries using the following commands. ※ Install PyTorch first (cuXXX should match your CUDA version).
※ My environment: PyTorch==2.8.0+cu129, CUDA==12.9

$ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cuXXX
$ pip install -r requirements.txt

You can train the VAE+GMM model by running main.py.

$ python main.py
  • train_model() trains the VAE+GMM model.
  • vae_module.decode() reconstructs images from the parameters of the posterior distribution estimated by the GMM.
def main() -> None:
    """Main function to orchestrate the VAE-GMM training process."""
    # Parse arguments and setup configuration
    config = parse_arguments()

    # Setup environment
    setup_directories(config)
    device = setup_device_and_seed(config)

    # Create data loaders
    train_loader, all_loader, train_size = create_data_loaders(config)

    # Train the VAE-GMM model
    train_model(config, train_loader, all_loader, device)

    # Reconstruct images from trained model
    print("\nGenerating reconstructed images...")
    vae_module.decode(
        iteration=1,  # Use model from iteration 1
        decode_k=1,  # Use cluster 1 for reconstruction
        sample_num=16,  # Generate 16 samples
        model_dir=config.debug_dir,
        device=device,
    )

Changes with and without mutual learning (for MNIST)

Latent space of the VAE

Left: without mutual learning ・ Right: with mutual learning
Plot using t-SNE

Plot using PCA

ELBO of VAE

The red line is the ELBO before mutual learning; the blue line is the ELBO after mutual learning.
The horizontal axis is the VAE training iteration and the vertical axis is the ELBO of the VAE
(in general, the higher the ELBO, the better).
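
To make "ELBO" concrete: it is the reconstruction term minus a KL term, and under mutual learning the standard-normal prior in that KL term is replaced by the cluster Gaussian sent by the GMM. Below is a sketch of such a KL term, not the repo's exact loss code.

import torch
from torch import distributions as D

def kl_term(enc_mu, enc_logvar, prior_mu, prior_var):
    """KL( q(x|o) || p(x) ), where p(x) is the diagonal Gaussian from the GMM."""
    q = D.Normal(enc_mu, (0.5 * enc_logvar).exp())  # encoder posterior
    p = D.Normal(prior_mu, prior_var.sqrt())        # cluster Gaussian from GMM
    return D.kl_divergence(q, p).sum(dim=-1)        # per-sample KL

# With prior_mu = 0 and prior_var = 1 this reduces to the standard VAE KL term.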

Clustering performance (in GMM)

Clustering accuracy results (this addresses the clustering performance of the GMM within the VAE+GMM model).
Left: without mutual learning ・ Right: with mutual learning
The horizontal axis is the GMM training iteration and the vertical axis is accuracy.

Image reconstruction from the Gaussian parameters estimated by the GMM, using the VAE decoder

The GMM clusters the latent variables of the VAE. By sampling random variables from the posterior distribution estimated by the GMM and feeding them to the VAE decoder, images can be reconstructed.

"x" represents the mean parameter of the normal distribution for each cluster.
In this example, a random variable is sampled from a Gaussian distribution with K=1.
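
A sketch of this sampling step; the cluster parameters, latent dimensionality, and decoder call below are placeholders rather than the repo's actual names.

import torch
from torch import distributions as D

mu_k = torch.zeros(12)   # mean of the chosen cluster (the "x" mark above)
var_k = torch.ones(12)   # variance of the chosen cluster

x = D.Normal(mu_k, var_k.sqrt()).sample((16,))  # 16 latent samples, shape (16, 12)
# recon = vae.decoder(x)  # hypothetical decoder call mapping latents to images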

Reconstructed images produced by feeding the sampled random variables to the VAE decoder:

Special Thanks

The GMM implementation is based on 【Python】4.4.2:ガウス混合モデルにおける推論:ギブスサンプリング【緑ベイズ入門のノート】 (in English: "[Python] 4.4.2: Inference in a Gaussian Mixture Model: Gibbs Sampling [Notes on the Green Bayes Introduction]").