jaxformer by salesforce

JAX library for LLM training on TPUs

created 2 years ago
292 stars

Top 91.4% on sourcepulse

View on GitHub
1 Expert Loves This Project
Project Summary

Jaxformer is a minimal JAX library designed for training large language models (LLMs) on TPUs, leveraging the pjit() operator for efficient data and model parallelism. It targets researchers and engineers working with LLMs on TPU infrastructure, offering a streamlined approach to distributed training and fine-tuning.

How It Works

Jaxformer utilizes JAX's pjit() for advanced parallelism strategies, including data and model sharding across TPU pods. It implements a push-based TCP/IP protocol for inter-worker communication and orchestration, enabling efficient scaling for models up to 6 billion parameters. The library also supports xmap() emulation via pjit() sharding and includes features like distributed checkpointing and scan() for optimized JIT compilation.
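
To make the sharding idea concrete, here is a minimal, illustrative sketch of the Megatron-style data/model parallel pattern that pjit() enables. It is not jaxformer's actual API: the mlp function, axis names, and shapes are invented for the example, and it assumes a recent JAX release where pjit accepts in_shardings/out_shardings (older releases used in_axis_resources).

```python
# Sketch only: data parallelism on the "data" mesh axis, tensor (model)
# parallelism on the "model" mesh axis, in the style jaxformer describes.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.pjit import pjit

# Build a 2D device mesh; on a single host this degenerates to a (n, 1) mesh.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

def mlp(params, x):
    # One feed-forward layer; the weight matrix is split along the "model" axis.
    return jnp.dot(x, params["w"]) + params["b"]

with mesh:
    sharded_mlp = pjit(
        mlp,
        in_shardings=(
            {"w": P(None, "model"), "b": P("model")},  # column-split weights
            P("data", None),                           # batch split across data axis
        ),
        out_shardings=P("data", "model"),
    )
    params = {"w": jnp.ones((128, 256)), "b": jnp.zeros((256,))}
    x = jnp.ones((8, 128))
    y = sharded_mlp(params, x)  # runs sharded across the mesh
```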

Quick Start & Requirements

  • Installation: pip install -r requirements.txt (after cloning the repository).
  • Prerequisites: Python 3.9+, JAX with TPU support (jax[tpu]), Google Cloud SDK, and potentially specific CUDA versions for A100 fine-tuning. TPU hardware (v3/v4) is the primary target.
  • Setup: Requires provisioning TPUs and configuring cloud credentials. Detailed instructions are provided for local CPU, local TPU, and remote TPU setups (a quick device-visibility check is sketched after this list).
  • Resources: Training LLMs on TPUs requires significant computational resources.
  • Links: GitHub Repository
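
As a quick sanity check after installing jax[tpu] on a TPU VM, the following snippet (not part of jaxformer itself) confirms that JAX can see the accelerator:

```python
# Verify that JAX can see the TPU (falls back to CPU when run locally).
import jax

print(jax.devices())        # expect TpuDevice entries on a TPU VM
print(jax.device_count())   # e.g. 8 cores on a single v3-8 or v4-8 host
```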

Highlighted Details

  • Supports training and fine-tuning of models like CodeGen and ProGen2 on TPU v3/v4.
  • Features Megatron-style pjit() sharding for efficient model parallelism.
  • Includes stateful, resumable data loading using TFRecords (see the sketch after this list).
  • Offers distributed checkpointing with full state recovery.
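
For a rough picture of what stateful, resumable TFRecord loading can look like, here is a sketch using tf.data with a skip-based resume strategy. The feature name, sequence length, and function signature are invented for illustration; jaxformer's actual loader API may differ.

```python
import tensorflow as tf

def make_dataset(paths, batch_size, batches_consumed=0, seq_len=2048):
    # Each TFRecord example is assumed to hold one fixed-length token sequence.
    feature_spec = {"tokens": tf.io.FixedLenFeature([seq_len], tf.int64)}

    def parse(record):
        return tf.io.parse_single_example(record, feature_spec)["tokens"]

    ds = tf.data.TFRecordDataset(paths)
    ds = ds.map(parse, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    # Resume deterministically by skipping the batches already consumed
    # before the last checkpoint (the consumed count is saved with the checkpoint).
    ds = ds.skip(batches_consumed)
    return ds.prefetch(tf.data.AUTOTUNE)
```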

Maintenance & Community

The project is associated with Salesforce. The README names Ben Wang, James Bradbury, and Zak Stone among its contributors, and it provides no links to community channels such as Discord or Slack.

Licensing & Compatibility

The README does not explicitly state a license, so check the repository for a LICENSE file before adopting the code. Compatibility with commercial or closed-source projects is not specified.

Limitations & Caveats

The library is described as "minimal" and primarily targets TPU hardware, with limited explicit support or instructions for other accelerators like GPUs (though an A100 fine-tuning example is present). The setup process, especially for TPUs, can be complex and requires familiarity with Google Cloud.

Health Check

  • Last commit: 1 year ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 7 stars in the last 90 days

Explore Similar Projects

Starred by Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), Philipp Schmid (DevRel at Google DeepMind), and 2 more.

tpu-starter by ayaka14732

0%
536
TPU guide for JAX-based ML workflows on Google Cloud
created 3 years ago
updated 1 year ago
Starred by Stas Bekman (Author of Machine Learning Engineering Open Book; Research Engineer at Snowflake) and Zhiqiang Xie (Author of SGLang).

veScale by volcengine

0.1%
839
PyTorch-native framework for LLM training
created 1 year ago
updated 3 weeks ago
Starred by Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), Thomas Wolf (Cofounder of Hugging Face), and 3 more.

levanter by stanford-crfm

0.5%
628
Framework for training foundation models with JAX
created 3 years ago
updated 21 hours ago