PyTorch implementation of DDPO for diffusion model finetuning
This repository implements Denoising Diffusion Policy Optimization (DDPO) in PyTorch for finetuning diffusion models, specifically targeting Stable Diffusion. It enables users to customize image generation based on user-defined prompts and reward functions, offering a flexible approach to aligning AI image generation with specific aesthetic or functional goals.
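To make the "user-defined reward function" idea concrete, here is a minimal illustrative example: a reward that scores images by mean pixel brightness. The function name and the nested-list image format are hypothetical; the repo's actual reward functions operate on batches of decoded samples (e.g. aesthetic or prompt-alignment scorers).

```python
def brightness_reward(image):
    """Hypothetical reward: mean pixel brightness of an H x W x 3 image
    given as nested lists of floats in [0, 1]. Illustrative only; not
    the repository's actual reward-function interface."""
    total, count = 0.0, 0
    for row in image:
        for r, g, b in row:
            total += (r + g + b) / 3.0
            count += 1
    return total / count
```

Any scalar-valued function of a generated image (and optionally its prompt) can play this role; DDPO then pushes the model toward images that score higher.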
How It Works
DDPO frames diffusion model finetuning as a reinforcement learning problem. It generates images using a diffusion model, evaluates them with a reward function, and then updates the diffusion model's policy (its parameters) to maximize expected rewards. The implementation leverages LoRA for efficient finetuning, significantly reducing memory requirements.
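The policy-gradient update described above can be sketched as a PPO-style clipped surrogate loss over denoising timesteps, which is the objective used by DDPO's importance-sampling variant. All names and the flat-list interface below are illustrative assumptions, not the repo's actual API:

```python
import math

def ddpo_clipped_loss(logp_new, logp_old, advantages, clip_range=1e-4):
    """Toy clipped surrogate loss over denoising steps (hypothetical
    interface). Each entry pairs a log-probability of a denoising step
    under the current and old policies with a reward-derived advantage."""
    losses = []
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)  # importance weight for this step
        unclipped = -adv * ratio
        clipped = -adv * max(min(ratio, 1.0 + clip_range), 1.0 - clip_range)
        losses.append(max(unclipped, clipped))  # pessimistic (clipped) term
    return sum(losses) / len(losses)
```

Minimizing this loss increases the likelihood of denoising trajectories that earned high rewards, while the clipping keeps each update close to the sampling policy.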
Quick Start & Requirements
Clone the repository, then run `pip install -e .`.
Highlighted Details
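The configuration files mentioned in this section (`config/base.py`, `config/dgx.py`) define training settings. A hypothetical sketch of what such a config might contain, written as a plain dict; the repo's actual field names and values may differ:

```python
def example_config():
    """Hypothetical sketch of settings in the style of config/base.py.
    Field names and values are illustrative assumptions."""
    return {
        "pretrained_model": "runwayml/stable-diffusion-v1-5",  # base checkpoint
        "use_lora": True,        # LoRA finetuning keeps memory requirements low
        "sample_batch_size": 8,  # images sampled per reward evaluation
        "num_epochs": 100,
        "learning_rate": 3e-4,
    }
```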
The `trl` library also offers a `DDPOTrainer` based on this work. Configuration files (`config/base.py`, `config/dgx.py`) provide example settings.
Maintenance & Community
The `trl` integration was contributed by @metric-space.
Licensing & Compatibility
Limitations & Caveats