tunix by google

JAX-native library for efficient LLM post-training

Created 6 months ago
1,652 stars

Top 25.4% on SourcePulse

Project Summary

Tunix is a JAX-native library for Large Language Model (LLM) post-training that streamlines supervised fine-tuning, reinforcement learning (RL), and knowledge distillation. It targets researchers and engineers who need efficient, scalable LLM adaptation on accelerators, leveraging JAX for high-performance computation and integrating natively with Flax NNX. The library aims to simplify complex post-training workflows through a modular, extensible framework.

How It Works

Tunix is built upon JAX, enabling accelerated, distributed computation, particularly on TPUs. It supports various post-training methodologies, including parameter-efficient fine-tuning (PEFT) via LoRA/Q-LoRA, multiple RL algorithms like PPO, GRPO, and GSPO-token, and preference alignment through Direct Preference Optimization (DPO). For knowledge distillation, it offers strategies such as matching output probability distributions (Logit Strategy), aligning attention mechanisms, and matching intermediate feature representations. The architecture emphasizes modularity for customization and efficiency, with native support for common model sharding strategies (DP, FSDP, TP) designed for multi-host distributed training.
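To make the knowledge-distillation Logit Strategy concrete, here is a minimal sketch of a temperature-scaled KL-divergence loss between teacher and student logits written in plain JAX. It illustrates the general technique only; the function and argument names are hypothetical and do not reflect Tunix's actual API.

    import jax
    import jax.numpy as jnp

    def logit_distillation_loss(student_logits, teacher_logits, temperature=2.0):
        # KL(teacher || student) on temperature-softened distributions.
        # Hypothetical helper, not Tunix's API; inputs have shape (batch, seq, vocab).
        t = temperature
        teacher_probs = jax.nn.softmax(teacher_logits / t, axis=-1)
        teacher_logp = jax.nn.log_softmax(teacher_logits / t, axis=-1)
        student_logp = jax.nn.log_softmax(student_logits / t, axis=-1)
        kl = jnp.sum(teacher_probs * (teacher_logp - student_logp), axis=-1)
        # Scale by t^2 so gradient magnitudes stay comparable across temperatures.
        return (t ** 2) * jnp.mean(kl)

    # Toy check with random logits standing in for real model outputs.
    k1, k2 = jax.random.split(jax.random.PRNGKey(0))
    student = jax.random.normal(k1, (2, 8, 128))
    teacher = jax.random.normal(k2, (2, 8, 128))
    print(logit_distillation_loss(student, teacher))

The same pattern extends to the attention-transfer and feature-matching strategies, swapping the logit tensors for attention maps or pooled intermediate representations.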

Quick Start & Requirements

  • Installation:
      • Recommended: pip install "tunix[prod]"
      • Latest from GitHub: pip install git+https://github.com/google/tunix
      • Editable install for development: git clone https://github.com/google/tunix.git && cd tunix && pip install -e ".[dev]"
  • Prerequisites: JAX, Flax NNX. Optimized for TPUs and distributed training on accelerators (a quick device-check snippet follows this list).
  • Resources: Detailed examples and tutorials are available. A setup script for Jupyter notebooks on single-host GCP TPU VMs is provided. Official documentation and more examples are planned.
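Before launching the notebooks, it can help to confirm that JAX actually sees the intended accelerators. The snippet below is a generic sanity check using standard JAX calls only; it is independent of Tunix.

    import jax
    import jax.numpy as jnp

    # List visible devices (TPU cores, GPUs, or the CPU fallback).
    print(jax.devices())

    # Run a small jitted computation to confirm the backend works end to end.
    x = jnp.ones((1024, 1024))
    print(jax.jit(lambda a: (a @ a).sum())(x))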

Highlighted Details

  • Supports full-weight fine-tuning and PEFT methods such as LoRA/Q-LoRA.
  • Implements RL algorithms including PPO, GRPO, and token-level GSPO.
  • Includes Direct Preference Optimization (DPO) for preference alignment (see the loss sketch after this list).
  • Offers diverse knowledge distillation techniques: Logit, Attention Transfer, and Feature Pooling.
  • Designed for efficient distributed training with native support for DP, FSDP, and TP sharding on accelerators.
  • Upcoming features include Agentic RL, advanced algorithms, and vLLM integration for optimized rollouts.
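For readers unfamiliar with DPO, the following is a minimal sketch of the standard DPO objective in plain JAX, computed from per-sequence log-probabilities under the policy and a frozen reference model. It illustrates the algorithm in general; the names are hypothetical and this is not Tunix's implementation.

    import jax
    import jax.numpy as jnp

    def dpo_loss(policy_chosen_logps, policy_rejected_logps,
                 ref_chosen_logps, ref_rejected_logps, beta=0.1):
        # Each argument is a (batch,) array of summed per-sequence log-probabilities.
        # Hypothetical helper, not Tunix's API.
        policy_margin = policy_chosen_logps - policy_rejected_logps
        ref_margin = ref_chosen_logps - ref_rejected_logps
        # -log sigmoid(beta * (policy margin - reference margin)), averaged over the batch.
        return -jnp.mean(jax.nn.log_sigmoid(beta * (policy_margin - ref_margin)))

    # Toy example: the policy already prefers the chosen responses slightly.
    loss = dpo_loss(
        policy_chosen_logps=jnp.array([-12.0, -9.5]),
        policy_rejected_logps=jnp.array([-14.0, -11.0]),
        ref_chosen_logps=jnp.array([-13.0, -10.0]),
        ref_rejected_logps=jnp.array([-13.5, -10.5]),
    )
    print(loss)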

Maintenance & Community

Tunix is in "Early Development," with active expansion of capabilities and performance improvements planned. Contributions are welcomed, with a draft contribution process available. Users can engage via the Tunix GitHub discussion forum for feature requests, issues, and questions. A notable collaboration with GRL (Game Reinforcement Learning) integrates seamless TPU support for scalable RL experiments.

Licensing & Compatibility

The provided README does not specify a software license, so users should verify licensing directly in the repository before adopting Tunix in commercial or closed-source projects.

Limitations & Caveats

The library is in early development, so features, usability, and performance optimization are still works in progress, and the contribution process is still being formalized. Specific limitations, such as unsupported platforms or known bugs, are not detailed in the provided text.

Health Check

  • Last Commit: 10 hours ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 245
  • Issues (30d): 14
  • Star History: 1,594 stars in the last 30 days

Explore Similar Projects

Starred by Casper Hansen (Author of AutoAWQ), Yineng Zhang (Inference Lead at SGLang; Research Scientist at Together AI), and 5 more.

xtuner by InternLM — LLM fine-tuning toolkit for research
Top 0.3% · 5k stars
Created 2 years ago · Updated 9 hours ago