tunix by google

JAX-native library for efficient LLM post-training

Created 11 months ago
2,182 stars

Top 20.3% on SourcePulse

Project Summary

A JAX-native library for Large Language Model (LLM) post-training, Tunix streamlines supervised fine-tuning, reinforcement learning (RL), and knowledge distillation. It targets researchers and engineers seeking efficient, scalable LLM adaptation on accelerators, leveraging JAX for high-performance computation and seamless integration with Flax NNX. The library aims to simplify complex post-training workflows, offering a modular and extensible framework.

How It Works

Tunix is built upon JAX, enabling accelerated, distributed computation, particularly on TPUs. It supports various post-training methodologies, including parameter-efficient fine-tuning (PEFT) via LoRA/Q-LoRA, multiple RL algorithms like PPO, GRPO, and GSPO-token, and preference alignment through Direct Preference Optimization (DPO). For knowledge distillation, it offers strategies such as matching output probability distributions (Logit Strategy), aligning attention mechanisms, and matching intermediate feature representations. The architecture emphasizes modularity for customization and efficiency, with native support for common model sharding strategies (DP, FSDP, TP) designed for multi-host distributed training.
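The logit-matching distillation strategy described above can be sketched as a temperature-scaled KL divergence between teacher and student output distributions. The function below is an illustrative stand-in written in plain JAX, not Tunix's actual API; all names are hypothetical.

```python
import jax
import jax.numpy as jnp

def logit_distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) over temperature-softened token distributions."""
    t_log_probs = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)
    s_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    kl = jnp.sum(jnp.exp(t_log_probs) * (t_log_probs - s_log_probs), axis=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures.
    return jnp.mean(kl) * temperature**2

# Toy (batch, vocab) logits standing in for model outputs.
student = jax.random.normal(jax.random.PRNGKey(0), (4, 32))
teacher = jax.random.normal(jax.random.PRNGKey(1), (4, 32))
loss = logit_distillation_loss(student, teacher)
```

In a training loop, this loss would be differentiated with `jax.grad` with respect to the student parameters only; the teacher's logits are treated as constants.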

Quick Start & Requirements

  • Installation: Recommended: pip install "tunix[prod]". Latest from GitHub: pip install git+https://github.com/google/tunix. Editable install for development: git clone https://github.com/google/tunix.git && cd tunix && pip install -e ".[dev]".
  • Prerequisites: JAX, Flax NNX. Optimized for TPUs and distributed training on accelerators.
  • Resources: Detailed examples and tutorials are available. A setup script for Jupyter notebooks on single-host GCP TPU VMs is provided. Official documentation and more examples are planned.

Highlighted Details

  • Supports full weights fine-tuning and PEFT methods like LoRA/Q-LoRA.
  • Implements RL algorithms including PPO, GRPO, and token-level GSPO.
  • Includes Direct Preference Optimization (DPO) for preference alignment.
  • Offers diverse knowledge distillation techniques: Logit, Attention Transfer, and Feature Pooling.
  • Designed for efficient distributed training with native support for DP, FSDP, and TP sharding on accelerators.
  • Upcoming features include Agentic RL, advanced algorithms, and vLLM integration for optimized rollouts.
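The DP/FSDP/TP sharding listed above maps naturally onto a `jax.sharding` device mesh. A minimal single-host sketch follows; the axis names and layout are illustrative assumptions, not Tunix's actual configuration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange devices on a (dp, fsdp, tp) mesh. On a single-device host this
# degenerates to a 1x1x1 mesh; on a TPU slice each axis can be larger.
devices = np.array(jax.devices()[:1]).reshape(1, 1, 1)
mesh = Mesh(devices, axis_names=("dp", "fsdp", "tp"))

# Shard a weight matrix: rows along the "fsdp" axis, columns along "tp".
weights = jnp.ones((8, 8))
sharded = jax.device_put(weights, NamedSharding(mesh, P("fsdp", "tp")))
```

With more devices, the same `PartitionSpec` transparently splits the array across hosts, which is the mechanism multi-host distributed training builds on.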

Maintenance & Community

Tunix is in "Early Development," with active expansion of capabilities and performance improvements planned. Contributions are welcomed, and a draft contribution process is available. Users can engage via the Tunix GitHub discussion forum for feature requests, issues, and questions. A notable collaboration with GRL (Game Reinforcement Learning) brings seamless TPU support for scalable RL experiments.

Licensing & Compatibility

The provided README does not specify a software license. Verify the project's licensing before depending on it in commercial or closed-source work.

Limitations & Caveats

The library is in early development, indicating ongoing work on features, usability, and performance optimization. The contribution process is still being formalized. Specific limitations regarding unsupported platforms or known bugs are not detailed in the provided text.

Health Check

  • Last Commit: 19 hours ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 162
  • Issues (30d): 9
  • Star History: 42 stars in the last 30 days

Explore Similar Projects

Starred by Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), Wing Lian (Founder of Axolotl AI), and 3 more.

ROLL by alibaba

RL library for large language models
  • Top 1.4% on SourcePulse · 3k stars
  • Created 9 months ago · Updated 22 hours ago