doremi by sangmichaelxie

PyTorch for optimizing data mixtures in language model datasets

created 2 years ago
340 stars

Top 82.2% on sourcepulse

Project Summary

DoReMi provides a PyTorch implementation for optimizing data mixture weights in language modeling datasets, aimed at researchers and practitioners who want to improve LLM training efficiency. It addresses the problem of choosing domain proportions for pretraining data: using Distributionally Robust Optimization (DRO), it finds a mixture that keeps the worst-case excess loss across domains low, leading to faster convergence and improved downstream performance.

How It Works

DoReMi employs a minimax optimization strategy. It trains a small proxy model using DRO, which dynamically adjusts domain weights based on the proxy model's excess loss relative to a pretrained reference model. This reference model anchors the optimization, preventing excessive pessimism for high-entropy domains. The resulting optimized data mixture is then used to train a larger language model, achieving significant speedups and performance gains.
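
The core weight update can be sketched in a few lines of PyTorch. This is an illustrative reading of the DoReMi update, not the repository's exact code; names such as update_domain_weights, eta, and smoothing are assumptions.

    import torch

    def update_domain_weights(alpha, proxy_losses, ref_losses, eta=1.0, smoothing=1e-3):
        """One DoReMi-style step: exponentiated gradient ascent on per-domain excess loss.

        alpha:        current domain weights, shape (num_domains,), summing to 1
        proxy_losses: per-domain mean loss of the proxy model on the current batch
        ref_losses:   per-domain mean loss of the frozen reference model
        """
        # Clip excess loss at zero so domains the proxy already fits as well as the
        # reference cannot pull their weight further down (the reference anchors this).
        excess = torch.clamp(proxy_losses - ref_losses, min=0.0)

        # Multiplicative-weights update, renormalized onto the simplex.
        new_alpha = torch.softmax(torch.log(alpha) + eta * excess, dim=0)

        # Mix with the uniform distribution so no domain's weight collapses to zero.
        uniform = torch.full_like(new_alpha, 1.0 / new_alpha.numel())
        return (1 - smoothing) * new_alpha + smoothing * uniform

The proxy model's per-domain losses are then rescaled by the updated weights for its own gradient step, and the mixture reported at the end is the average of these weights over all proxy-training steps.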

Quick Start & Requirements

  • Install via pip install -e doremi after cloning the repository.
  • Requires PyTorch.
  • FlashAttention2 compilation (bash scripts/setup_flash.sh) can take hours.
  • Configuration of paths and environment variables in constants.sh is necessary.
  • Official quick-start and example scripts for data preprocessing and training are provided.

Highlighted Details

  • A 280M-parameter proxy model produces a mixture on which an 8B-parameter model reaches baseline performance 2.6x faster.
  • Outputs optimized domain weights as JSON files (consuming them is sketched after this list).
  • Includes a fast, resumable dataloader with domain-level weighted sampling.
  • Integrates with HuggingFace Trainer and FlashAttention2.
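
A minimal sketch of consuming the output, assuming a flat JSON mapping from domain name to weight; the file name and schema here are illustrative, not the repository's exact format.

    import json
    import numpy as np

    # Hypothetical path to the domain weights produced by a DoReMi run.
    with open("doremi_domain_weights.json") as f:
        domain_weights = json.load(f)

    domains = list(domain_weights)
    probs = np.array([domain_weights[d] for d in domains], dtype=np.float64)
    probs /= probs.sum()  # guard against rounding in the stored weights

    # Draw each example's domain in proportion to the optimized mixture, which is
    # what the repository's weighted, resumable dataloader does during training.
    rng = np.random.default_rng(0)
    print(rng.choice(domains, size=8, p=probs))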

Maintenance & Community

The project is associated with Sang Michael Xie and appears to be a research implementation. No specific community channels or roadmap are detailed in the README.

Licensing & Compatibility

The repository does not explicitly state a license. This may require clarification for commercial use or integration into closed-source projects.

Limitations & Caveats

The implementation targets single-node, multi-GPU training and does not support multi-node setups. There are noted differences from the original Google research paper, including PyTorch vs. JAX, subtle model architecture variations, and tokenizer differences, which may affect direct reproducibility. With gradient accumulation, domain weights are updated only once per accumulated step, so they are slightly stale within each accumulation window and results are not identical to training without accumulation.

Health Check

  • Last commit: 1 year ago
  • Responsiveness: 1+ week
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 18 stars in the last 90 days
