A PyTorch implementation for optimizing data mixtures in language model datasets
DoReMi provides a PyTorch implementation for optimizing data mixture weights (domain proportions) in language modeling datasets, aimed at researchers and practitioners who want to improve LLM training efficiency. It tackles the problem of choosing domain proportions for training data with Distributionally Robust Optimization (DRO): the mixture is tuned against worst-case excess loss across domains rather than any single objective, which in practice yields faster convergence and better downstream performance.
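For intuition, a data mixture is simply a distribution over source domains used when sampling training data. A toy, purely hypothetical consumer of such weights might look like the sketch below (domain names and weights are made-up placeholders, not DoReMi outputs):

```python
# Toy illustration of sampling training domains from a tuned mixture.
# Domain names and weights here are made-up placeholders, not DoReMi outputs.
import torch

domain_weights = {"web": 0.45, "books": 0.20, "code": 0.20, "wikipedia": 0.15}

domains = list(domain_weights)
probs = torch.tensor([domain_weights[d] for d in domains])

def sample_domains(batch_size: int):
    """Pick a source domain for each example in proportion to the mixture."""
    idx = torch.multinomial(probs, num_samples=batch_size, replacement=True)
    return [domains[i] for i in idx.tolist()]

print(sample_domains(8))  # e.g. ['web', 'code', 'web', 'books', ...]
```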
How It Works
DoReMi employs a minimax optimization strategy. A small proxy model is trained with DRO, dynamically adjusting domain weights in proportion to the proxy's excess loss relative to a pretrained reference model. The reference model anchors the optimization and prevents excessive pessimism on high-entropy domains. The domain weights averaged over the proxy run become the optimized mixture, which is then used to train a larger language model, yielding significant speedups and performance gains.
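A minimal sketch of that weight update, assuming per-domain average losses for the proxy and reference models are already computed (names and hyperparameters below are illustrative, and the repository's per-token clipping and loss normalization details are omitted):

```python
# Illustrative sketch of a DoReMi-style exponentiated-gradient update of the
# domain mixture; not the repository's exact implementation.
import torch

def update_domain_weights(
    domain_weights: torch.Tensor,    # current mixture over k domains, sums to 1
    proxy_losses: torch.Tensor,      # per-domain loss of the small proxy model
    reference_losses: torch.Tensor,  # per-domain loss of the pretrained reference
    step_size: float = 1.0,          # DRO step size (eta in the paper)
    smoothing: float = 1e-3,         # mixing with uniform keeps weights positive
) -> torch.Tensor:
    # Excess loss: how much worse the proxy is than the reference on each domain,
    # clipped at zero so domains the proxy already handles well are not upweighted.
    excess = torch.clamp(proxy_losses - reference_losses, min=0.0)

    # Exponentiated-gradient ascent on the mixture, then renormalize to the simplex.
    new_weights = domain_weights * torch.exp(step_size * excess)
    new_weights = new_weights / new_weights.sum()

    # Smooth with the uniform distribution so no domain's weight collapses to zero.
    k = domain_weights.numel()
    return (1.0 - smoothing) * new_weights + smoothing / k

# The proxy model's training loss is reweighted by the updated mixture at each
# step; averaging these per-step weights over the run gives the final mixture
# used to train the larger model.
```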
Quick Start & Requirements
Run pip install -e doremi after cloning the repository. Installing flash attention via bash scripts/setup_flash.sh (a from-source build) can take hours. Setting the paths in constants.sh is necessary before running the provided scripts.
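Put together, setup looks roughly like the following; the clone URL assumes the usual sangmichaelxie/doremi repository, and the paths in constants.sh are yours to fill in.

```bash
# Rough setup sketch; adjust the URL and paths to your environment.
git clone https://github.com/sangmichaelxie/doremi.git
pip install -e doremi              # editable install after cloning
cd doremi
bash scripts/setup_flash.sh        # builds flash attention; can take hours
# Edit constants.sh to point at your local cache, data, and output
# directories before launching any of the training scripts.
```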
Highlighted Details
Maintenance & Community
The project is associated with Sang Michael Xie and appears to be a research implementation. No specific community channels or roadmap are detailed in the README.
Licensing & Compatibility
The repository does not explicitly state a license. This may require clarification for commercial use or integration into closed-source projects.
Limitations & Caveats
The implementation targets single-node, multi-GPU training and does not support multi-node setups. There are noted differences from the setup in the original Google research paper, including the framework (PyTorch here versus JAX), subtle model architecture variations, and tokenizer differences, which may affect direct reproducibility. With gradient accumulation, domain weights are only updated once per accumulation cycle, so results are not exactly equivalent to training without accumulation (the weights are stale within the cycle).