minimal-diffusion by VSehwag

Minimal diffusion model implementation for synthetic data generation

Created 3 years ago · 290 stars · Top 91.7% on sourcepulse

Project Summary

This repository provides a minimal yet resourceful implementation of diffusion models, offering pre-trained models and synthetic data for nine diverse datasets. It aims to democratize the use of synthetic data generated by diffusion models, demonstrating its effectiveness for training downstream classifiers. The project is suitable for researchers and practitioners interested in generative models and synthetic data generation.

How It Works

The project implements class-conditional diffusion models built on a UNet architecture. It trains these models on each dataset and then samples synthetic images from them. The quality and diversity of the synthetic data are evaluated by training a ResNet50 classifier solely on the generated images and comparing its accuracy on real test data with that of a classifier trained on real data. This comparison highlights diffusion models' ability to generate high-quality, diverse data that can be competitive with real data for certain downstream tasks, especially in low-data regimes.
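For intuition, the training objective described above can be sketched in a few lines of PyTorch. This is a minimal, hypothetical illustration of a class-conditional DDPM training step, not the repository's code: the linear beta schedule, the eps_model(x_t, t, labels) interface, and all hyperparameters here are assumptions.

    import torch
    import torch.nn.functional as F

    # Hypothetical sketch of one denoising-diffusion training step (not the repo's code).
    T = 1000                                    # number of diffusion timesteps (assumed)
    betas = torch.linspace(1e-4, 0.02, T)       # linear noise schedule (a common default)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def ddpm_loss(eps_model, x0, labels):
        """Predict the noise added to x0 at a random timestep t (class-conditional)."""
        b = x0.shape[0]
        t = torch.randint(0, T, (b,), device=x0.device)          # random timestep per image
        noise = torch.randn_like(x0)                              # epsilon ~ N(0, I)
        a_bar = alphas_cumprod.to(x0.device)[t].view(b, 1, 1, 1)
        x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise    # forward process q(x_t | x_0)
        eps_pred = eps_model(x_t, t, labels)                       # UNet conditioned on t and label
        return F.mse_loss(eps_pred, noise)

At sampling time, the learned reverse process starts from pure Gaussian noise and denoises it over a fixed number of steps (250 in the sampling command below) to produce class-conditional images.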

Quick Start & Requirements

  • Install: pip install scipy opencv-python (assumes torch and torchvision are installed).
  • Requirements: PyTorch, torchvision, SciPy, OpenCV.
  • Training: CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py --arch UNet --dataset cifar10 --class-cond --epochs 500 (example for 4 GPUs).
  • Sampling: CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py --arch UNet --dataset cifar10 --class-cond --sampling-only --sampling-steps 250 --num-sampled-images 50000 --pretrained-ckpt path_to_pretrained_model (example for sampling).
  • Scripts for training and sampling are available in the ./scripts directory.

Highlighted Details

  • Synthetic data alone achieves competitive classification scores on real data across multiple datasets (see the sketch after this list).
  • For smaller datasets such as Flowers and Cars, classifiers trained on synthetic data outperform classifiers trained on the limited real data.
  • Diffusion models are noted for ease of training, excellent mode coverage, photorealism, and consistent training pipelines.
  • Released assets include pre-trained diffusion models, 50,000 synthetic images per dataset, and downstream classifiers.
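As a hedged illustration of that evaluation recipe (the directory layout, transforms, and training schedule below are assumptions, not the repository's released scripts), a downstream ResNet50 can be trained purely on sampled synthetic images and then scored on real test data roughly as follows:

    import torch
    import torch.nn as nn
    from torchvision import datasets, models, transforms

    # Hypothetical sketch: train ResNet50 only on generated images, evaluate on real CIFAR-10.
    # "synthetic_cifar10/" is a placeholder for a class-per-folder dump of the 50,000 samples.
    transform = transforms.ToTensor()
    synthetic = datasets.ImageFolder("synthetic_cifar10/", transform=transform)
    real_test = datasets.CIFAR10("data/", train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(synthetic, batch_size=256, shuffle=True)
    test_loader = torch.utils.data.DataLoader(real_test, batch_size=256)

    model = models.resnet50(num_classes=10).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    for epoch in range(100):                     # placeholder schedule, not a tuned one
        model.train()
        for x, y in train_loader:
            opt.zero_grad()
            loss = nn.functional.cross_entropy(model(x.cuda()), y.cuda())
            loss.backward()
            opt.step()

    # Accuracy on real images is the metric the project reports for synthetic-data quality.
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            correct += (model(x.cuda()).argmax(1) == y.cuda()).sum().item()
    print(f"accuracy on real test data: {correct / len(real_test):.3f}")

The gap between this number and the accuracy of the same classifier trained on real images is what the project uses to quantify the quality and coverage of the generated data.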

Maintenance & Community

The implementation is motivated by Jonathan Ho's original diffusion model work and OpenAI's PyTorch implementation. The project's experimental findings are inspired by the author's previous work on synthetic data.

Licensing & Compatibility

The README does not explicitly state a license. Real image licenses follow their respective datasets.

Limitations & Caveats

The project uses lower-resolution images (64x64 for most datasets) to reduce computational cost, which can make classification harder. Hyperparameters for the downstream classifiers are not tuned per dataset, and the reported classification numbers are not intended to be state-of-the-art.

Health Check

  • Last commit: 11 months ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 8 stars in the last 90 days
