ml-cross-entropy by apple

PyTorch module for memory-efficient cross-entropy in LLMs

created 8 months ago
506 stars

Top 62.4% on sourcepulse

Project Summary

This library provides Cut Cross-Entropy (CCE), a memory-efficient method for computing the cross-entropy loss in large-vocabulary language models. It targets researchers and engineers working with LLMs, offering significant memory reductions during training without compromising speed or convergence.

How It Works

CCE avoids materializing the full logit matrix by computing only the logit for the correct token and performing the log-sum-exp reduction on the fly. This is done with custom Triton kernels that carry out the matrix multiplications and reductions in on-chip SRAM, in the style of FlashAttention, drastically reducing global memory consumption. The backward pass is further optimized by skipping terms whose contribution to the gradient is negligible, improving throughput.
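To make the memory claim concrete, here is a minimal naive reference in plain PyTorch for the quantity CCE computes. This is the baseline CCE avoids, not the library's API, and the shapes are illustrative:

```python
import torch
import torch.nn.functional as F

# Naive reference for the loss CCE computes. This path materializes the
# full [N, V] logit matrix -- exactly the allocation CCE avoids by
# computing tiles of it in on-chip SRAM and reducing on the fly.
N, D, V = 1024, 2048, 32_000     # tokens, hidden size, vocab (illustrative)
embeddings = torch.randn(N, D)   # final hidden states
classifier = torch.randn(V, D)   # lm_head / classifier weight matrix
targets = torch.randint(0, V, (N,))

logits = embeddings @ classifier.T   # [N, V]: the tensor CCE never builds
loss = F.cross_entropy(logits, targets)
# Per token i with label y_i: loss_i = logsumexp_v(e_i . c_v) - e_i . c_{y_i}
```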

Quick Start & Requirements

  • Install: pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git"
  • Requirements: Python 3.10+, PyTorch 2.4+, Triton 3.0+, Ampere (or newer) GPU. A torch.compile fallback is available for unsupported systems (e.g., macOS).
  • Usage: from cut_cross_entropy import linear_cross_entropy (see the sketch after this list)
  • Docs: https://github.com/apple/ml-cross-entropy
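A minimal usage sketch following the README's example; the impl keyword for selecting an implementation comes from the repository docs and is worth verifying against the version you install:

```python
import torch
from cut_cross_entropy import linear_cross_entropy

# embeddings: final hidden states; classifier: the lm_head weight matrix.
# The default CCE kernel requires a CUDA GPU (Ampere or newer).
embeddings = torch.randn(8192, 2048, dtype=torch.bfloat16, device="cuda")
classifier = torch.randn(256_000, 2048, dtype=torch.bfloat16, device="cuda")
targets = torch.randint(0, 256_000, (8192,), device="cuda")

loss = linear_cross_entropy(embeddings, classifier, targets)
# On unsupported systems, the README documents a torch.compile fallback,
# selected via the impl argument (e.g. impl="torch_compile").
```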

Highlighted Details

  • Reduces loss computation memory from 24 GB to 1 MB for Gemma 2 (2B).
  • Supports vocabulary parallelism for sharded classifier weights.
  • Integrates with Hugging Face Transformers via cce_patch for the Llama, Phi3, Mistral, and Gemma2 families (see the sketch after this list).
  • Offers multiple implementations (cce, torch_compile, cce_kahan, cce_kahan_full_c, cce_exact) for different precision and performance needs.
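A hedged sketch of the Transformers integration: the cce_patch import path follows the repository's README, but the exact call pattern should be checked against the installed version, and the checkpoint name below is only a placeholder:

```python
from transformers import AutoModelForCausalLM
from cut_cross_entropy.transformers import cce_patch

# Load any supported family (Llama, Phi3, Mistral, Gemma2); the
# checkpoint name is a placeholder, not a recommendation.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
model = cce_patch(model)  # loss in forward(labels=...) now uses CCE
```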

Maintenance & Community

The project is maintained by Apple. The README does not document community channels, contribution guidelines, or other engagement details.

Licensing & Compatibility

The repository ships a LICENSE file, but the README does not summarize its terms. Compatibility with commercial use or closed-source linking is therefore not detailed.

Limitations & Caveats

The primary CCE implementation requires an Ampere or newer GPU. While a torch.compile fallback exists, its performance characteristics may differ. The exact license terms are not specified, which could impact commercial adoption.

Health Check

  • Last commit: 5 days ago
  • Responsiveness: 1 week
  • Pull requests (30d): 4
  • Issues (30d): 2
  • Star history: 75 stars gained in the last 90 days

Explore Similar Projects

Starred by Chip Huyen (Author of AI Engineering, Designing Machine Learning Systems), Jaret Burkett (Founder of Ostris), and 1 more.

nunchaku by nunchaku-tech

2.1% · 3k stars
High-performance 4-bit diffusion model inference engine
created 8 months ago · updated 14 hours ago
Starred by Andrej Karpathy (Founder of Eureka Labs; formerly at Tesla, OpenAI; Author of CS 231n), Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), and 5 more.

Liger-Kernel by linkedin

0.6% · 5k stars
Triton kernels for efficient LLM training
created 1 year ago · updated 1 day ago