vqgan-training  by cloneofsimo

VAE trainer for latent diffusion models

Created 11 months ago
292 stars

Top 90.4% on SourcePulse

Project Summary

This repository provides a distributed trainer for the Variational Autoencoders (VAEs) used in latent diffusion models such as Stable Diffusion. It targets researchers and engineers in generative AI who need high-quality VAEs for image synthesis. Its main draw is a set of measures for stable, high-fidelity training: GAN and perceptual losses, gradient normalization, and a fixed variance.

How It Works

The trainer uses PyTorch's DistributedDataParallel (DDP) for multi-GPU training. A GAN loss, computed with a VGG16-based discriminator and a hinge objective, improves image quality, and LPIPS supplies the perceptual loss. For stability, it applies gradient normalization and fixes the variance at 0.1 rather than learning it, as standard VAEs do. The reconstruction loss combines LPIPS with a modified MSE computed on a low-pass filtered version of the image, balancing detail against color accuracy.
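
The hinge objective and the fixed variance described above can be illustrated with a minimal, framework-free sketch. It is not the repository's code: the function names are hypothetical, plain Python lists stand in for tensors, and it assumes the fixed variance of 0.1 applies to the encoder's latent distribution, which the summary does not state explicitly.

```python
import math
import random

FIXED_VARIANCE = 0.1  # value quoted in the summary; where it applies is an assumption


def d_hinge_loss(real_logits, fake_logits):
    """Discriminator hinge loss: push real logits above +1 and fake logits below -1."""
    real_term = sum(max(0.0, 1.0 - r) for r in real_logits) / len(real_logits)
    fake_term = sum(max(0.0, 1.0 + f) for f in fake_logits) / len(fake_logits)
    return real_term + fake_term


def g_hinge_loss(fake_logits):
    """Generator-side loss (here, the VAE): raise the discriminator's score on reconstructions."""
    return -sum(fake_logits) / len(fake_logits)


def sample_latent(mean, rng):
    """Reparameterized sample z = mean + sqrt(var) * eps, with a constant variance
    instead of a variance predicted by the encoder."""
    std = math.sqrt(FIXED_VARIANCE)
    return [m + std * rng.gauss(0.0, 1.0) for m in mean]


if __name__ == "__main__":
    print(d_hinge_loss([2.0, 0.5], [-2.0, 0.0]))  # 0.75
    print(g_hinge_loss([-2.0, 0.0]))              # 1.0
    print(sample_latent([0.0, 1.0], random.Random(0)))
```

Logits that already clear the margin (a real score of 2.0, a fake score of -2.0) contribute nothing to the discriminator loss, so its gradient comes only from samples it has not yet separated confidently.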

Quick Start & Requirements

  • Run command (launches 8 processes on one node, typically one per GPU): torchrun --nproc_per_node=8 vae_trainer.py
  • Prerequisites: PyTorch and a dataset in webdataset format (can be created with img2dataset).
  • Configuration: Supports arguments like --learning_rate_vae, --vae_ch, --vae_ch_mult, --do_ganloss.

Highlighted Details

  • Adds a hinge GAN loss with a VGG16-based discriminator for enhanced image quality.
  • Implements gradient normalization and fixed variance for stable training.
  • Employs a low-pass filtered MSE loss alongside LPIPS to balance reconstruction fidelity and color calibration.
  • Based on FLUX's VAE architecture with modifications for multi-head attention and constant variance.
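
To see why a low-pass filtered MSE helps with color calibration, the sketch below compares two reconstructions with identical plain MSE: one with fine-grained noise and one with a uniform brightness shift. It is an illustration only. The box blur is an assumed stand-in for whatever filter the repository uses, the function names are hypothetical, and LPIPS is omitted because it requires a pretrained network.

```python
def box_blur(image, radius=1):
    """Low-pass filter a 2D grayscale image (list of rows) with a mean filter.
    Border pixels average only over the neighbours that exist."""
    height, width = len(image), len(image[0])
    blurred = []
    for y in range(height):
        row = []
        for x in range(width):
            ys = range(max(0, y - radius), min(height, y + radius + 1))
            xs = range(max(0, x - radius), min(width, x + radius + 1))
            window = [image[j][i] for j in ys for i in xs]
            row.append(sum(window) / len(window))
        blurred.append(row)
    return blurred


def mse(a, b):
    """Mean squared error between two images of the same shape."""
    diffs = [(pa - pb) ** 2 for ra, rb in zip(a, b) for pa, pb in zip(ra, rb)]
    return sum(diffs) / len(diffs)


def lowpass_mse(target, reconstruction, radius=1):
    """MSE computed after blurring both images, so only coarse errors are penalized."""
    return mse(box_blur(target, radius), box_blur(reconstruction, radius))


# Flat gray target, a checkerboard-noise reconstruction, and a uniformly brighter one.
target = [[0.5] * 4 for _ in range(4)]
noisy = [[0.5 + (0.1 if (x + y) % 2 == 0 else -0.1) for x in range(4)] for y in range(4)]
shifted = [[0.6] * 4 for _ in range(4)]

if __name__ == "__main__":
    print(mse(target, noisy), mse(target, shifted))                  # both close to 0.01
    print(lowpass_mse(target, noisy), lowpass_mse(target, shifted))  # noisy is far smaller
```

After blurring, the checkerboard error mostly averages out while the brightness shift survives intact, so this term penalizes color and tone drift and leaves fine detail to the perceptual loss.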

Maintenance & Community

  • Maintained by Simo Ryu.
  • Citation provided for academic use.

Licensing & Compatibility

  • License: Not explicitly stated in the README.

Limitations & Caveats

The README does not specify the license, which could impact commercial use or integration into closed-source projects. The setup requires a dataset in a specific format (webdataset), necessitating an additional preprocessing step.

Health Check

  • Last commit: 11 months ago
  • Responsiveness: Inactive
  • Pull requests (30d): 0
  • Issues (30d): 0
  • Star history: 0 stars in the last 30 days
