DiT by facebookresearch

PyTorch implementation of diffusion models with transformers (DiT)

created 2 years ago
7,634 stars

Top 6.9% on sourcepulse

View on GitHub
Project Summary

This repository provides the official PyTorch implementation of DiT (Scalable Diffusion Models with Transformers), a novel approach to diffusion models that replaces the U-Net backbone with a transformer architecture. It is targeted at researchers and practitioners in generative AI and computer vision who are interested in state-of-the-art image generation. The key benefit is improved scalability and performance, demonstrated by state-of-the-art FID scores on ImageNet benchmarks.

How It Works

DiT models replace the convolutional U-Net backbone commonly used in diffusion models with a transformer that operates on patches of the image's latent representation. This design scales well: increased compute (GFLOPs, from deeper/wider transformers or more input tokens) correlates strongly with lower FID. The transformer's self-attention is hypothesized to capture the long-range dependencies crucial for high-fidelity image generation more effectively than convolutions.
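
To make this concrete, below is a minimal toy sketch (not the repository's model): a 32x32x4 VAE-style latent is split into 2x2 patches, embedded as tokens, processed by a plain PyTorch transformer encoder, and un-patchified back to the latent's shape. The real DiT also injects timestep and class conditioning via adaLN-Zero blocks, which this sketch omits; all module sizes here are illustrative.

```python
# Toy illustration of the DiT idea: patchify a latent, attend over the patch tokens,
# un-patchify back to the latent shape. Conditioning is omitted for brevity.
import torch
import torch.nn as nn

class ToyDiT(nn.Module):
    def __init__(self, latent_size=32, in_channels=4, patch_size=2,
                 hidden=384, depth=6, heads=6):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (latent_size // patch_size) ** 2
        # Patch embedding: a strided conv turns the latent into a sequence of tokens.
        self.patch_embed = nn.Conv2d(in_channels, hidden,
                                     kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden))
        layer = nn.TransformerEncoderLayer(hidden, heads, dim_feedforward=4 * hidden,
                                           batch_first=True, norm_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        # Linear head maps each token back to a patch of latent channels.
        self.head = nn.Linear(hidden, patch_size * patch_size * in_channels)

    def forward(self, x):                       # x: (B, 4, 32, 32) latent from a VAE encoder
        B, C, H, W = x.shape
        p = self.patch_size
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)   # (B, N, hidden)
        tokens = self.blocks(tokens + self.pos_embed)             # self-attention over patches
        out = self.head(tokens)                                   # (B, N, p*p*C)
        h, w = H // p, W // p
        out = out.view(B, h, w, p, p, C).permute(0, 5, 1, 3, 2, 4)
        return out.reshape(B, C, H, W)          # e.g. predicted noise in latent space

eps = ToyDiT()(torch.randn(2, 4, 32, 32))
print(eps.shape)                                # torch.Size([2, 4, 32, 32])
```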

Quick Start & Requirements

  • Install: conda env create -f environment.yml and conda activate DiT.
  • Prerequisites: PyTorch, Conda. CPU-only execution is possible by removing CUDA requirements from environment.yml.
  • Running Pre-trained Models: python sample.py --image-size 512 --seed 1 for DiT-XL/2 (512x512).
  • Resources: Training requires significant GPU resources (e.g., 8x A100s for DiT-XL/2 at 256x256).
  • Links: Paper, Project Page, Hugging Face Diffusers.
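
For sampling without cloning the repository, the linked Hugging Face Diffusers integration exposes DiT as a pipeline. The sketch below assumes the DiTPipeline API and the facebook/DiT-XL-2-256 checkpoint id as documented by Diffusers; verify against the current Diffusers docs before relying on it.

```python
# Hedged sketch: class-conditional sampling via the Diffusers DiT pipeline (assumed API).
import torch
from diffusers import DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

class_ids = pipe.get_label_ids(["golden retriever"])  # ImageNet class names -> label ids
image = pipe(class_labels=class_ids, num_inference_steps=25).images[0]
image.save("dit_sample.png")
```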

Highlighted Details

  • Achieves state-of-the-art FID of 2.27 on ImageNet 256x256 and 3.04 on ImageNet 512x512 with DiT-XL/2.
  • PyTorch implementation reproduces JAX results, with FP32 PyTorch weights showing marginally better FID (2.21) than the original paper's JAX results (2.27).
  • Includes scripts for sampling, parallel sampling for evaluation (FID, IS), and training with PyTorch DDP.
  • TF32 matmuls are enabled by default for faster training/sampling on Ampere GPUs.
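
For reference, the TF32 default boils down to the standard PyTorch flags shown below; the exact lines in the repo's scripts may differ, but this is what the setting amounts to (and how to turn it off when reproducing the FP32 FID numbers).

```python
import torch

# TF32 trades a little matmul precision for speed on Ampere and newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True   # set False to force full-precision FP32 matmuls
torch.backends.cudnn.allow_tf32 = True         # same toggle for cuDNN convolutions
```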

Maintenance & Community

  • Developed by facebookresearch (William Peebles, Saining Xie).
  • Codebase borrows from OpenAI's diffusion repositories.
  • No explicit community links (Discord/Slack) are provided in the README.

Licensing & Compatibility

  • License: CC-BY-NC (Creative Commons Attribution-NonCommercial).
  • Restrictions: Non-commercial use only. The NC clause rules out use in commercial products and services, whether open- or closed-source.

Limitations & Caveats

The CC-BY-NC license restricts commercial use. The README notes potential speedups with Flash Attention and torch.compile, and lists basic features like FID monitoring and AMP/bfloat16 support as enhancements to be added.
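
Since torch.compile is only mentioned as a possible speedup, trying it is left to the user. A hedged sketch, assuming the DiT_models registry exposed by the repo's models.py (names and arguments may differ):

```python
import torch
from models import DiT_models  # assumed: model registry in the repo's models.py

# Build the XL/2 config for 256x256 images (32x32 latents) and compile it.
model = DiT_models["DiT-XL/2"](input_size=32, num_classes=1000).cuda()
model = torch.compile(model)   # PyTorch 2.x; no other code changes needed
```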

Health Check

  • Last commit: 1 year ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 1
  • Issues (30d): 3
  • Star History: 442 stars in the last 90 days

Explore Similar Projects

Starred by Lilian Weng (Cofounder of Thinking Machines Lab), Patrick Kidger (Core Contributor to JAX ecosystem), and 4 more.

glow by openai

0.1% · 3k stars
Generative flow research paper code
created 7 years ago, updated 1 year ago
Starred by Andrej Karpathy (Founder of Eureka Labs; Formerly at Tesla, OpenAI; Author of CS 231n), Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), and 4 more.

taming-transformers by CompVis

0.1% · 6k stars
Image synthesis research paper using transformers
created 4 years ago, updated 1 year ago