maxtext by AI-Hypercomputer

JAX LLM for high-performance, scalable training/inference on TPUs/GPUs

Created 2 years ago
1,948 stars

Top 22.5% on SourcePulse

View on GitHub
Project Summary

MaxText is a high-performance, scalable LLM framework built in pure JAX, designed for training and inference on Google Cloud TPUs and GPUs. It targets researchers and production engineers seeking a performant, flexible starting point for large-scale LLM projects, offering high Model FLOPs Utilization (MFU) and ease of customization through forking.

How It Works

MaxText leverages JAX and the XLA compiler to achieve high performance and scalability without manual kernel optimization: XLA's automatic optimizations handle the low-level tuning, which simplifies development and enables efficient execution across large TPU and GPU clusters. The framework supports various LLM architectures and offers features such as ahead-of-time (AOT) compilation for faster startup and stack trace collection for debugging cluster issues.

Quick Start & Requirements

  • Run training: python3 -m MaxText.train (after setting up the environment and installing dependencies); a launch sketch follows this list.
  • Prerequisites: Python, JAX with TPU or GPU support (e.g., jax[tpu]), and access to the corresponding hardware. A setup.sh script is provided to install dependencies.
  • Resources: Requires access to Google Cloud TPUs or GPUs. AOT compilation can be performed on a single machine.
  • Docs: Getting Started, Gemma Guide, Llama2 Guide, Mixtral Guide, DeepSeek Guide.
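
As a minimal illustration of launching a run, the command below follows the python3 -m MaxText.train pattern referenced above; the config path, flag names, and values (run name, output bucket, step count) are assumptions based on MaxText's base config and should be verified against the repository's getting-started docs.

    # Hedged sketch: short smoke-test run on synthetic data
    # (flags assumed from MaxText/configs/base.yml)
    python3 -m MaxText.train MaxText/configs/base.yml \
        run_name=smoke_test \
        base_output_directory=gs://<your-bucket>/maxtext-output \
        dataset_type=synthetic \
        per_device_batch_size=1 \
        steps=10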

Highlighted Details

  • Supports Llama 2/3/4, Mistral/Mixtral, Gemma 1-3, and DeepSeek families.
  • Achieves high MFU (e.g., 60-70% on TPU v5p) and scales to thousands of chips.
  • Features Ahead-of-Time (AOT) compilation for optimized training runs (see the invocation sketch after this list).
  • Includes stack trace collection for debugging distributed training issues.
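
As a rough illustration of the AOT workflow, MaxText ships a train_compile entry point that compiles the training step for a target topology without needing the actual hardware; the topology flag names and values below are assumptions and should be checked against the repository's AOT documentation.

    # Hedged sketch: ahead-of-time compile for a v5e-256 slice
    # (topology flag names and values are illustrative)
    python3 -m MaxText.train_compile MaxText/configs/base.yml \
        compile_topology=v5e-256 \
        compile_topology_num_slices=1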

Maintenance & Community

  • Actively updated with new model support (e.g., Llama 4, Gemma 3).
  • Modular import structure introduced in April 2025.
  • No explicit community links (Discord/Slack) are mentioned in the README.

Licensing & Compatibility

  • The README does not explicitly state a license.

Limitations & Caveats

  • Currently supports text-only models; multi-modal support is in development.
  • Context length is limited to 8k for some models, with ongoing optimization efforts.
  • AOT compilation requires matching compilation and execution environments for predictable behavior.

Health Check

  • Last Commit: 7 hours ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 175
  • Issues (30d): 15
  • Star History: 32 stars in the last 30 days

Explore Similar Projects

Starred by Jeff Hammerbacher (Cofounder of Cloudera), Edward Sun (Research Scientist at Meta Superintelligence Lab), and 1 more.

jaxformer by salesforce
JAX library for LLM training on TPUs
298 stars
Created 3 years ago, updated 1 year ago

Starred by Jiayi Pan (Author of SWE-Gym; MTS at xAI), Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), and 12 more.

EasyLM by young-geng
LLM training/finetuning framework in JAX/Flax
2k stars
Created 2 years ago, updated 1 year ago

Starred by Lianmin Zheng (Coauthor of SGLang, vLLM), Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), and 1 more.

MiniCPM by OpenBMB
Ultra-efficient LLMs for end devices, achieving 5x+ speedup
8k stars
Created 1 year ago, updated 3 weeks ago

Starred by Andrej Karpathy (Founder of Eureka Labs; Formerly at Tesla, OpenAI; Author of CS 231n), Stefan van der Walt (Core Contributor to scientific Python ecosystem), and 12 more.

litgpt by Lightning-AI
LLM SDK for pretraining, finetuning, and deploying 20+ high-performance LLMs
13k stars
Created 2 years ago, updated 3 days ago