Minimal diffusion model implementation for synthetic data generation
This repository provides a minimal yet complete implementation of diffusion models, along with pre-trained models and synthetic data for nine diverse datasets. It aims to democratize the use of synthetic data generated by diffusion models by demonstrating its effectiveness for training downstream classifiers. The project is suitable for researchers and practitioners interested in generative models and synthetic data generation.
How It Works
The project implements class-conditional diffusion models using a UNet architecture. It trains these models on various datasets and then samples synthetic images from them. The quality and diversity of the synthetic data are evaluated by training a ResNet50 classifier solely on the generated images and comparing its accuracy on a real test set with that of a classifier trained on real data. This approach highlights diffusion models' ability to generate high-quality, diverse data that can be competitive with real data for certain downstream tasks, especially in low-data regimes.
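As a rough illustration of what a class-conditional diffusion model optimizes, the sketch below shows one DDPM-style training step: the UNet predicts the noise added at a random timestep and is trained with a simple MSE objective. Function and variable names here are illustrative assumptions, not the repository's actual API.

```python
# Hedged sketch of one class-conditional DDPM training step (not the repo's exact code).
# Assumes a UNet `model(x_t, t, y)` that predicts the noise added at timestep t.
import torch
import torch.nn.functional as F

T = 1000                                   # number of diffusion timesteps (assumption)
betas = torch.linspace(1e-4, 0.02, T)      # linear noise schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def diffusion_loss(model, x0, y):
    """x0: batch of real images (NCHW), y: class labels used for conditioning."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)           # random timesteps
    noise = torch.randn_like(x0)                               # epsilon ~ N(0, I)
    a_bar = alphas_cumprod.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise       # forward noising q(x_t | x0)
    pred_noise = model(x_t, t, y)                               # UNet predicts the noise
    return F.mse_loss(pred_noise, noise)
```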
Quick Start & Requirements
Install the extra dependencies with pip install scipy opencv-python (this assumes torch and torchvision are already installed).

Training (example for 4 GPUs):
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py --arch UNet --dataset cifar10 --class-cond --epochs 500

Sampling (example, 50,000 images with 250 sampling steps):
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py --arch UNet --dataset cifar10 --class-cond --sampling-only --sampling-steps 250 --num-sampled-images 50000 --pretrained-ckpt path_to_pretrained_model

Additional reference scripts are provided in the ./scripts directory.
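After sampling, the downstream evaluation described in How It Works can be reproduced along these lines. This is a sketch under assumptions: the file name, array keys, image layout, and training schedule are hypothetical, so check the sampling code for the actual output format rather than relying on this snippet.

```python
# Hedged sketch: train a ResNet50 purely on sampled synthetic images and measure
# accuracy on the real CIFAR-10 test set. The synthetic data is assumed to be a
# NumPy archive with "images" (uint8, NHWC) and "labels" arrays -- an assumption.
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, models, transforms

data = np.load("sampled_images.npz")                      # hypothetical file name
images = torch.from_numpy(data["images"]).permute(0, 3, 1, 2).float() / 255.0
labels = torch.from_numpy(data["labels"]).long()
train_loader = DataLoader(TensorDataset(images, labels), batch_size=256, shuffle=True)

test_set = datasets.CIFAR10("./data", train=False, download=True,
                            transform=transforms.ToTensor())
test_loader = DataLoader(test_set, batch_size=256)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = models.resnet50(num_classes=10).to(device)         # trained from scratch
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

for epoch in range(100):                                   # illustrative schedule
    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        F.cross_entropy(model(x), y).backward()
        opt.step()

model.eval()
correct = 0
with torch.no_grad():
    for x, y in test_loader:
        pred = model(x.to(device)).argmax(dim=1)
        correct += (pred.cpu() == y).sum().item()
print(f"accuracy on real test set: {correct / len(test_set):.4f}")
```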
Highlighted Details
Maintenance & Community
The implementation is motivated by Jonathan Ho's original diffusion model work and OpenAI's PyTorch implementation. The project's experimental findings are inspired by the author's previous work on synthetic data.
Licensing & Compatibility
The README does not explicitly state a license. Licenses for the real images follow those of their respective datasets.
Limitations & Caveats
The project uses lower-resolution images (64x64 for most datasets) to reduce computational cost, which can make classification harder. Hyperparameters for the downstream classifiers are not tuned per dataset, and the reported classification numbers are not intended to be state-of-the-art.