DiT by facebookresearch

PyTorch implementation of diffusion models with transformers (DiT)

Created 2 years ago
7,825 stars

Top 6.7% on SourcePulse

Project Summary

This repository provides the official PyTorch implementation of DiT (Scalable Diffusion Models with Transformers), a novel approach to diffusion models that replaces the U-Net backbone with a transformer architecture. It is targeted at researchers and practitioners in generative AI and computer vision who are interested in state-of-the-art image generation. The key benefit is improved scalability and performance, demonstrated by state-of-the-art FID scores on ImageNet benchmarks.

How It Works

DiT replaces the convolutional U-Net backbone commonly used in diffusion models with a transformer that operates on patches of latent image representations: a pre-trained VAE encodes images into a latent space, the latent is split into patches, and the resulting tokens are processed by standard transformer blocks. This design scales well, as demonstrated by the correlation between increased forward-pass GFLOPs (from deeper/wider transformers or more input tokens) and lower FID scores. The transformer's self-attention is hypothesized to be more effective at capturing the long-range dependencies crucial for high-fidelity image generation.
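
As a rough illustration of that flow (not the repository's actual models.py; the module names and sizes below are assumptions, and positional embeddings plus adaLN-Zero timestep/class conditioning are omitted), a toy latent-patch transformer might look like this:

    import torch
    import torch.nn as nn

    class ToyLatentPatchTransformer(nn.Module):
        """Illustrative stand-in for a DiT: latent -> patch tokens -> transformer -> latent."""
        def __init__(self, latent_channels=4, patch_size=2, hidden=384, depth=6, heads=6):
            super().__init__()
            self.patch_size = patch_size
            # Patchify: each p x p latent patch becomes one token.
            self.patchify = nn.Conv2d(latent_channels, hidden,
                                      kernel_size=patch_size, stride=patch_size)
            layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=heads,
                                               dim_feedforward=4 * hidden,
                                               batch_first=True, norm_first=True)
            self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
            # Project each token back to a p x p patch of the noise prediction.
            self.unpatchify = nn.Linear(hidden, latent_channels * patch_size ** 2)

        def forward(self, z):  # z: (B, C, H, W) latent from a VAE encoder
            B, C, H, W = z.shape
            p = self.patch_size
            tokens = self.patchify(z).flatten(2).transpose(1, 2)  # (B, N, hidden)
            tokens = self.blocks(tokens)                          # global self-attention over patches
            out = self.unpatchify(tokens)                         # (B, N, C * p * p)
            out = out.view(B, H // p, W // p, C, p, p)
            return out.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

    # A 32x32x4 latent corresponds to a 256x256 RGB image through an SD-style VAE.
    print(ToyLatentPatchTransformer()(torch.randn(2, 4, 32, 32)).shape)  # (2, 4, 32, 32)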

Quick Start & Requirements

  • Install: conda env create -f environment.yml and conda activate DiT.
  • Prerequisites: PyTorch, Conda. CPU-only execution is possible by removing CUDA requirements from environment.yml.
  • Running Pre-trained Models: python sample.py --image-size 512 --seed 1 for DiT-XL/2 (512x512).
  • Resources: Training requires significant GPU resources (e.g., 8x A100s for DiT-XL/2 at 256x256).
  • Links: Paper, Project Page, Hugging Face Diffusers.
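
For the Hugging Face Diffusers route, a minimal class-conditional sampling sketch is shown below; the DiTPipeline usage and the model id "facebook/DiT-XL-2-256" are taken from the Diffusers docs and should be verified against the current release:

    # Hedged sketch: class-conditional sampling via the Diffusers DiTPipeline.
    import torch
    from diffusers import DiTPipeline, DPMSolverMultistepScheduler

    pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    # Map human-readable ImageNet class names to class ids.
    class_ids = pipe.get_label_ids(["golden retriever"])

    images = pipe(class_labels=class_ids,
                  num_inference_steps=25,
                  generator=torch.manual_seed(0)).images
    images[0].save("sample.png")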

Highlighted Details

  • Achieves state-of-the-art FID of 2.27 on ImageNet 256x256 and 3.04 on ImageNet 512x512 with DiT-XL/2.
  • PyTorch implementation reproduces JAX results, with FP32 PyTorch weights showing marginally better FID (2.21) than the original paper's JAX results (2.27).
  • Includes scripts for sampling, parallel sampling for evaluation (FID, IS), and training with PyTorch DDP.
  • TF32 matmuls are enabled by default for faster training/sampling on Ampere GPUs.
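
For context, TF32 is controlled by PyTorch's standard backend flags; the snippet below is a generic illustration rather than a copy of the repo's scripts, so check train.py / sample.py for exactly which flags are set:

    import torch

    # Standard PyTorch TF32 switches (effective on Ampere and newer GPUs).
    # Disable them if you need strict FP32 numerics, e.g., for exact FID reproduction.
    torch.backends.cuda.matmul.allow_tf32 = True   # TF32 for matrix multiplies
    torch.backends.cudnn.allow_tf32 = True         # TF32 for cuDNN convolutions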

Maintenance & Community

  • Developed by facebookresearch (William Peebles, Saining Xie).
  • Codebase borrows from OpenAI's diffusion repositories.
  • No explicit community links (Discord/Slack) are provided in the README.

Licensing & Compatibility

  • License: CC-BY-NC (Creative Commons Attribution-NonCommercial).
  • Restrictions: Non-commercial use only. Compatibility with closed-source projects is restricted due to the NC clause.

Limitations & Caveats

The CC-BY-NC license restricts commercial use. The README notes potential speedups with Flash Attention and torch.compile, and lists basic features like FID monitoring and AMP/bfloat16 support as enhancements to be added.
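
As a rough sketch of those optional speedups in plain PyTorch 2.x (illustrative only, not code shipped by the repository): torch.compile wraps an existing module, and F.scaled_dot_product_attention can dispatch to FlashAttention kernels on supported GPUs and dtypes:

    import torch
    import torch.nn.functional as F

    # Illustrative only: "model" stands for a DiT module loaded from the repo.
    # model = torch.compile(model)  # PyTorch 2.x graph compilation

    # Fused attention: drop-in replacement for softmax(QK^T / sqrt(d)) @ V.
    # On supported GPUs/dtypes this can use FlashAttention kernels; on CPU it
    # falls back to the math implementation.
    q = k = v = torch.randn(2, 16, 256, 72)        # (batch, heads, tokens, head_dim)
    out = F.scaled_dot_product_attention(q, k, v)
    print(out.shape)                                # torch.Size([2, 16, 256, 72])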

Health Check

  • Last Commit: 1 year ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0

Star History

  • 120 stars in the last 30 days

Explore Similar Projects

Starred by Andrej Karpathy (Founder of Eureka Labs; formerly at Tesla, OpenAI; author of CS 231n), Jiayi Pan (Author of SWE-Gym; MTS at xAI), and 15 more.

taming-transformers by CompVis

0.1%
6k
Image synthesis research paper using transformers
Created 4 years ago
Updated 1 year ago
Starred by Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), Soumith Chintala (Coauthor of PyTorch), and 1 more.

jetson-inference by dusty-nv

0.1%
9k
Vision DNN library for NVIDIA Jetson devices
Created 9 years ago
Updated 11 months ago