PyTorch Implementation of Scalable Diffusion Models with Transformers (DiT)
This repository provides the official PyTorch implementation of DiT (Scalable Diffusion Models with Transformers), a novel approach to diffusion models that replaces the U-Net backbone with a transformer architecture. It is targeted at researchers and practitioners in generative AI and computer vision who are interested in state-of-the-art image generation. The key benefit is improved scalability and performance, demonstrated by state-of-the-art FID scores on ImageNet benchmarks.
How It Works
DiT models replace the convolutional U-Net backbone commonly used in diffusion models with a transformer architecture that operates on latent patches of images. This design choice allows for better scalability, as demonstrated by the correlation between increased GFLOPS (from deeper/wider transformers or more tokens) and lower FID scores. The transformer's self-attention mechanism is hypothesized to be more effective at capturing long-range dependencies crucial for high-fidelity image generation.
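The patchify-then-attend idea can be sketched in a few lines of PyTorch. This is a minimal illustration, not the official DiT architecture: the latent shape, patch size, and model dimensions below are placeholder values, and the real model adds positional embeddings, timestep/class conditioning (adaLN), and a final linear head.

```python
import torch
import torch.nn as nn

class PatchifiedTransformer(nn.Module):
    """Toy sketch: run self-attention over latent patches instead of a U-Net."""
    def __init__(self, latent_channels=4, patch_size=2, dim=128, depth=2, heads=4):
        super().__init__()
        # "Patchify": fold each p x p latent patch into a single token.
        self.patch_embed = nn.Conv2d(latent_channels, dim,
                                     kernel_size=patch_size, stride=patch_size)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, z):                           # z: (B, C, H, W) VAE latent
        tokens = self.patch_embed(z)                # (B, dim, H/p, W/p)
        tokens = tokens.flatten(2).transpose(1, 2)  # (B, N, dim) token sequence
        return self.blocks(tokens)                  # global self-attention over patches

z = torch.randn(1, 4, 32, 32)          # e.g. a 32x32 latent from the VAE encoder
out = PatchifiedTransformer()(z)
print(out.shape)                       # torch.Size([1, 256, 128]): 16x16 patches
```

Note how the scaling knobs mentioned above map directly onto this sketch: a smaller `patch_size` yields more tokens, while larger `dim`/`depth` yield a wider/deeper transformer, and each raises GFLOPS.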
Quick Start & Requirements
Create and activate the conda environment defined in `environment.yml`:

```bash
conda env create -f environment.yml
conda activate DiT
```

Then sample from the pretrained DiT-XL/2 (512x512) model:

```bash
python sample.py --image-size 512 --seed 1
```
Licensing & Compatibility
The CC-BY-NC license restricts commercial use.
Limitations & Caveats
The README notes potential speedups from Flash Attention and torch.compile, and lists basic features such as FID monitoring and AMP/bfloat16 support as enhancements still to be added.
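Since bfloat16 autocast support is listed as a pending enhancement, a user could apply it manually during inference with PyTorch's built-in `torch.autocast` context manager. The model below is a stand-in placeholder, not the actual DiT model, and whether this yields a speedup depends on hardware:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)        # placeholder; imagine a loaded DiT model here
x = torch.randn(2, 16)

# Autocast runs eligible ops (e.g. matmuls in nn.Linear) in bfloat16.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)                   # torch.bfloat16
```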