PyTorch implementation of Karras et al. (2022) diffusion models
This repository provides an implementation of the Karras et al. (2022) diffusion models for PyTorch, targeting researchers and practitioners in generative AI. It offers enhanced sampling algorithms, transformer-based diffusion models, and utilities for training and evaluation, aiming to improve sample quality and training efficiency.
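For example, the sampling utilities can be used on their own. Below is a minimal sketch of the sampler API (get_sigmas_karras and sample_dpmpp_2m are k-diffusion functions; the DummyDenoiser is a hypothetical stand-in for a trained model):

```python
import torch
import k_diffusion as K

class DummyDenoiser(torch.nn.Module):
    """Hypothetical stand-in: a real model maps (noisy x, sigma) -> denoised x."""
    def forward(self, x, sigma):
        return x

model = DummyDenoiser()
# Karras et al. (2022) noise schedule: 50 steps from sigma_max down to sigma_min.
sigmas = K.sampling.get_sigmas_karras(n=50, sigma_min=0.01, sigma_max=80.0)
x = torch.randn(4, 3, 64, 64) * sigmas[0]  # start from pure noise at sigma_max
samples = K.sampling.sample_dpmpp_2m(model, x, sigmas)
```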
How It Works
The core of k-diffusion is its implementation of diffusion models, including the techniques from the Karras et al. (2022) paper. It introduces a novel image_transformer_v2 model type: a hierarchical transformer inspired by Hourglass Transformer and DiT. This architecture employs efficient attention mechanisms, neighborhood attention (via NATTEN) and global attention (via FlashAttention-2), at different levels of the hierarchy, allowing a flexible trade-off between performance and custom CUDA kernel requirements. It also incorporates soft Min-SNR loss weighting for improved high-resolution training.
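As a rough illustration of how the hierarchy is specified, here is a model spec sketch loosely modeled on the JSON configs bundled in configs/; the exact key names and values are assumptions and should be checked against the repository:

```python
# Sketch of an image_transformer_v2 spec, loosely modeled on the repository's
# JSON configs; key names are assumptions, not the definitive schema.
model_spec = {
    "type": "image_transformer_v2",
    "input_channels": 3,
    "input_size": [256, 256],
    "patch_size": [4, 4],
    "depths": [2, 2, 4],        # transformer blocks per hierarchy level
    "widths": [128, 256, 512],  # embedding width per level
    "self_attns": [
        # Local neighborhood attention (NATTEN) at the high-resolution levels...
        {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
        {"type": "neighborhood", "d_head": 64, "kernel_size": 7},
        # ...and global attention (FlashAttention-2) at the lowest-resolution level.
        {"type": "global", "d_head": 64},
    ],
    "loss_weighting": "soft-min-snr",  # the soft Min-SNR weighting noted above
}
```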
Quick Start & Requirements
Install the library code from PyPI:

pip install k-diffusion

The PyPI package contains only the library; to run the training and inference scripts, clone the repository and install it in editable mode:

pip install -e <path to repository>

The image_transformer_v2 model type depends on custom CUDA kernels for its attention variants (NATTEN for neighborhood attention, FlashAttention-2 for global attention) and should be wrapped in torch.compile() for best speed. To train a small demo model on Oxford Flowers using the shifted window attention variant, which needs no custom kernels:

python train.py --config configs/config_oxford_flowers_shifted_window.json --name flowers_demo_001 --evaluate-n 0 --batch-size 32 --sample-n 36 --mixed-precision bf16
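For standalone use of a model outside the training script, a sketch of the pattern used in train.py follows; the helper names (K.config.load_config, make_model, make_denoiser_wrapper) are taken from the training script and their exact signatures are assumptions that may change:

```python
# Sketch: build a model from a bundled config and compile it, following the
# pattern in train.py; helper usage here is an assumption, check the script.
import torch
import k_diffusion as K

config = K.config.load_config("configs/config_oxford_flowers_shifted_window.json")
inner_model = K.config.make_model(config)
model = K.config.make_denoiser_wrapper(config)(inner_model)
model = torch.compile(model)  # recommended for image_transformer_v2 (PyTorch 2.x)
```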
Limitations & Caveats
The shifted window attention variant of image_transformer_v2 performs worse and runs slower than the NATTEN-based version. Models trained with one attention type require fine-tuning before they can be used with a different type, and the README's inference section is still marked "TODO".