jaxformer by salesforce

JAX library for LLM training on TPUs

Created 3 years ago
301 stars

Top 88.6% on SourcePulse

View on GitHub
Project Summary

Jaxformer is a minimal JAX library designed for training large language models (LLMs) on TPUs, leveraging the pjit() operator for efficient data and model parallelism. It targets researchers and engineers working with LLMs on TPU infrastructure, offering a streamlined approach to distributed training and fine-tuning.
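As a rough illustration of what pjit()-style data and model parallelism starts from, the generic JAX snippet below (not jaxformer's own code; the axis names and mesh layout are assumptions for the example) builds a logical device mesh with one axis for data parallelism and one for model parallelism:

    # Generic JAX sketch, not taken from jaxformer: build a logical 2D device mesh.
    # "data" and "model" are illustrative axis names.
    import numpy as np
    import jax
    from jax.sharding import Mesh

    devices = np.array(jax.devices()).reshape(-1, 1)   # (data-parallel, model-parallel) layout
    mesh = Mesh(devices, axis_names=("data", "model"))
    print(mesh)                                        # shows the mesh axes and their sizes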

How It Works

Jaxformer utilizes JAX's pjit() for advanced parallelism strategies, including data and model sharding across TPU pods. It implements a push-based TCP/IP protocol for inter-worker communication and orchestration, enabling efficient scaling for models up to 6 billion parameters. The library also supports xmap() emulation via pjit() sharding and includes features like distributed checkpointing and scan() for optimized JIT compilation.
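To make the sharding idea concrete, the following minimal sketch (generic JAX, not jaxformer's actual code; shapes and axis names are illustrative, and the keyword names in_shardings/out_shardings differ in older JAX releases, which used in_axis_resources) partitions a single matmul layer Megatron-style: batch rows along a "data" mesh axis, weight columns along a "model" axis, with XLA inserting the required collectives.

    # Minimal pjit() sketch of Megatron-style sharding for one matmul layer.
    # Not jaxformer's API; shapes and axis names are illustrative only.
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.pjit import pjit

    devices = np.array(jax.devices()).reshape(-1, 1)
    mesh = Mesh(devices, axis_names=("data", "model"))

    def layer(x, w):
        return jnp.dot(x, w)

    sharded_layer = pjit(
        layer,
        in_shardings=(P("data", None), P(None, "model")),  # shard batch rows / weight columns
        out_shardings=P("data", "model"),                  # activations sharded on both axes
    )

    with mesh:
        x = jnp.ones((8, 512))     # (batch, d_model), split across the "data" axis
        w = jnp.ones((512, 2048))  # (d_model, d_ff), split across the "model" axis
        y = sharded_layer(x, w)

The scan() feature mentioned above addresses a different bottleneck: rolling the per-layer loop into jax.lax.scan lets XLA compile one layer body instead of unrolling every layer, which reduces JIT compilation time for deep models.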

Quick Start & Requirements

  • Installation: pip install -r requirements.txt (after cloning the repository).
  • Prerequisites: Python 3.9+, JAX with TPU support (jax[tpu]), Google Cloud SDK, and potentially specific CUDA versions for A100 fine-tuning. TPU hardware (v3/v4) is the primary target.
  • Setup: Requires provisioning TPUs and configuring cloud credentials. Detailed instructions are provided for local CPU, local TPU, and remote TPU setups; a minimal device-visibility check is sketched after this list.
  • Resources: Training LLMs on TPUs requires significant computational resources.
  • Links: GitHub Repository
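After installing jax[tpu] on a TPU VM, a quick sanity check like the one below (a generic snippet, not from the jaxformer README) confirms that JAX can actually see the TPU cores before any training is launched:

    # Generic environment check, not part of jaxformer: verify JAX sees the accelerators.
    import jax

    print(jax.default_backend())   # expect "tpu" on a TPU VM, "cpu" locally
    print(jax.device_count())      # e.g. 8 on a v3-8 or v4-8 slice
    for d in jax.devices():
        print(d.platform, d.id)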

Highlighted Details

  • Supports training and fine-tuning of models like CodeGen and ProGen2 on TPU v3/v4.
  • Features Megatron-style pjit() sharding for efficient model parallelism.
  • Includes stateful, resumable data loading using TFRecords.
  • Offers distributed checkpointing with full state recovery (a generic sketch of resumable checkpointing follows this list).
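The resumable-training features can be pictured with the generic sketch below (not jaxformer's implementation; the file layout and pickle format are assumptions): the training step, parameters, and optimizer state are pulled to host memory and written atomically, so a restarted job can resume where it stopped.

    # Generic sketch of resumable checkpointing, not jaxformer's implementation.
    # Pulls device arrays to host, writes them atomically, restores them on resume.
    import os
    import pickle
    import jax

    def save_checkpoint(path, step, params, opt_state):
        state = {"step": step,
                 "params": jax.device_get(params),
                 "opt_state": jax.device_get(opt_state)}
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(state, f)
        os.replace(tmp, path)   # atomic rename, so a crash never leaves a partial checkpoint

    def load_checkpoint(path):
        with open(path, "rb") as f:
            state = pickle.load(f)
        return state["step"], state["params"], state["opt_state"]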

Maintenance & Community

The project is hosted under the Salesforce GitHub organization. The README credits contributors such as Ben Wang, James Bradbury, and Zak Stone; no links to community channels such as Discord or Slack are provided.

Licensing & Compatibility

The README does not explicitly state a license. However, given its association with Salesforce and the nature of the code, it is likely intended for research and internal use. Compatibility with commercial or closed-source projects is not specified.

Limitations & Caveats

The library is described as "minimal" and primarily targets TPU hardware, with limited explicit support or instructions for other accelerators like GPUs (though an A100 fine-tuning example is present). The setup process, especially for TPUs, can be complex and requires familiarity with Google Cloud.

Health Check

  • Last Commit: 1 year ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0

Star History

  • 5 stars in the last 30 days

Explore Similar Projects

Starred by Jiayi Pan (author of SWE-Gym; MTS at xAI), Chip Huyen (author of "AI Engineering" and "Designing Machine Learning Systems"), and 12 more.

EasyLM by young-geng

LLM training/finetuning framework in JAX/Flax
  • 2k stars
  • Created 2 years ago
  • Updated 1 year ago