PyTorch module for memory-efficient cross-entropy in LLMs
This library provides Cut Cross-Entropy (CCE), a memory-efficient method for computing the cross-entropy loss in large-vocabulary language models. It targets researchers and engineers working with LLMs, offering significant memory reductions during training without compromising speed or convergence.
How It Works
CCE avoids materializing the full logit matrix: it computes only the logit of the correct token and performs the log-sum-exp reduction over the vocabulary on the fly. A custom Triton kernel keeps the matrix multiplications and reductions in fast on-chip (SRAM) memory, in the spirit of FlashAttention, drastically reducing global-memory consumption. The backward pass is further optimized by skipping gradient contributions that are negligibly small, improving throughput.
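The core idea can be written down in a few lines of plain PyTorch. The sketch below is illustrative only (the function name, block size, and the Python loop are not the library's API; the actual implementation fuses these steps into a single Triton kernel):

```python
import torch

def blockwise_cross_entropy(embeddings, classifier, targets, block_size=8192):
    """Illustrative sketch: cross-entropy without materializing the (N, V) logit matrix."""
    n, v = embeddings.shape[0], classifier.shape[0]
    # Logit of the correct token: one dot product per example, no (N, V) matrix needed.
    correct_logit = (embeddings * classifier[targets]).sum(dim=-1)            # (N,)
    # Accumulate the log-sum-exp over the vocabulary one block of classifier rows at a time.
    lse = torch.full((n,), float("-inf"), device=embeddings.device, dtype=torch.float32)
    for start in range(0, v, block_size):
        block_logits = embeddings @ classifier[start:start + block_size].T   # (N, <=block_size)
        lse = torch.logaddexp(lse, torch.logsumexp(block_logits.float(), dim=-1))
    # Cross-entropy = log-sum-exp over all classes minus the correct-token logit.
    return (lse - correct_logit.float()).mean()
```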
Quick Start & Requirements
pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git"
The primary CCE kernel requires an Ampere or newer GPU; a torch.compile fallback is available for unsupported systems (e.g., macOS). The main entry point is:
from cut_cross_entropy import linear_cross_entropy
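A minimal usage sketch, assuming the (embeddings, classifier, labels) calling convention from the project README; the shapes, dtypes, and random tensors below are illustrative assumptions:

```python
import torch
from cut_cross_entropy import linear_cross_entropy

# Illustrative shapes; the CCE kernel reportedly expects fp16/bf16 inputs.
embeddings = torch.randn(4, 1024, 2048, device="cuda", dtype=torch.bfloat16,
                         requires_grad=True)              # (batch, seq, hidden dim)
classifier = torch.randn(128_256, 2048, device="cuda", dtype=torch.bfloat16,
                         requires_grad=True)              # (vocab size, hidden dim)
labels = torch.randint(0, 128_256, (4, 1024), device="cuda")

loss = linear_cross_entropy(embeddings, classifier, labels)
loss.backward()  # gradients flow to embeddings and classifier as usual
```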
Highlighted Details
- cce_patch for the Llama, Phi3, Mistral, and Gemma2 families (a usage sketch follows this list).
- Multiple loss implementations (cce, torch_compile, cce_kahan, cce_kahan_full_c, cce_exact) for different precision and performance needs.
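A sketch of how patching a Hugging Face model might look. Only the cce_patch name comes from the summary above; the import path, call signature, and checkpoint are assumptions for illustration, so consult the repository for the exact API:

```python
import torch
import transformers
from cut_cross_entropy.transformers import cce_patch  # import path assumed

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",       # hypothetical checkpoint from a supported family
    torch_dtype=torch.bfloat16,
)
model = cce_patch(model)             # assumed call: patch the model's loss to use CCE
# Training then proceeds as usual; the loss returned with labels=... now uses CCE.
```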
Maintenance & Community
The project is from Apple. Further community engagement details are not explicitly provided in the README.
Licensing & Compatibility
The code is released under the terms of the repository's LICENSE file; the specific license is not identified in the README. Compatibility with commercial use or closed-source linking is not detailed.
Limitations & Caveats
The primary CCE implementation requires an Ampere or newer GPU. While a torch.compile fallback exists, its performance characteristics may differ. The exact license terms are not specified, which could impact commercial adoption.