JAX library for LLM training on TPUs
Jaxformer is a minimal JAX library designed for training large language models (LLMs) on TPUs, leveraging the pjit() operator for efficient data and model parallelism. It targets researchers and engineers working with LLMs on TPU infrastructure, offering a streamlined approach to distributed training and fine-tuning.
How It Works
Jaxformer uses JAX's pjit() for advanced parallelism strategies, including data and model sharding across TPU pods. It implements a push-based TCP/IP protocol for inter-worker communication and orchestration, enabling efficient scaling for models up to 6 billion parameters. The library also supports xmap() emulation via pjit() sharding and includes features such as distributed checkpointing and scan() for faster JIT compilation.
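As a concrete illustration of the pjit() pattern described above, the sketch below (our own code, not Jaxformer's) builds a 2D device mesh with named "data" and "model" axes, shards a batch along the data axis and a weight matrix along the model axis, and runs a dense layer under that mesh. It assumes a recent JAX release where pjit() accepts in_shardings/out_shardings and an 8-core TPU host (e.g. v3-8); all function and variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.pjit import pjit

# Arrange the 8 local TPU cores into a 4 x 2 mesh:
# 4-way data parallelism, 2-way model (tensor) parallelism.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("data", "model"))

def layer(x, w):
    # A single dense layer; pjit partitions the matmul across the mesh.
    return jnp.dot(x, w)

sharded_layer = pjit(
    layer,
    in_shardings=(P("data", None), P(None, "model")),  # shard batch rows / weight columns
    out_shardings=P("data", "model"),
)

with mesh:
    x = jnp.ones((32, 1024))     # activations, batch split over the "data" axis
    w = jnp.ones((1024, 4096))   # weights, columns split over the "model" axis
    y = sharded_layer(x, w)      # (32, 4096), sharded over both mesh axes
```

The scan() mention works the same way in spirit: stacking the parameters of N identical layers and scanning over them lets XLA trace and compile the layer body once instead of N times. A minimal, similarly hypothetical sketch (continuing the imports above):

```python
def block(h, w):
    # Toy stand-in for a transformer block; h is the carried activation.
    return jnp.tanh(h @ w), None      # no per-layer outputs collected

def forward(h, stacked_w):
    # stacked_w: (num_layers, d, d) -- one leading axis per layer.
    h, _ = jax.lax.scan(block, h, stacked_w)
    return h

h0 = jnp.ones((32, 1024))
ws = jnp.ones((24, 1024, 1024))       # 24 layers' weights stacked along axis 0
out = jax.jit(forward)(h0, ws)        # compile time is independent of layer count
```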
Quick Start & Requirements
- Install with pip install -r requirements.txt after cloning the repository.
- Requires JAX with TPU support (jax[tpu]), the Google Cloud SDK, and potentially specific CUDA versions for A100 fine-tuning.
- TPU hardware (v3/v4) is the primary target (a quick device-visibility check is sketched below).
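Before launching anything, a quick sanity check with plain JAX (not a Jaxformer-specific API) confirms that jax[tpu] is installed and the TPU cores are visible:

```python
import jax

# On a correctly provisioned TPU VM this prints TpuDevice entries;
# on a single v3-8 host, jax.device_count() is typically 8.
print(jax.devices())
print(jax.device_count())
```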
Highlighted Details
- pjit() sharding for efficient model parallelism.
Maintenance & Community
The project is associated with Salesforce. The README mentions contributors such as Ben Wang, James Bradbury, and Zak Stone, but provides no explicit links to community channels like Discord or Slack.
Licensing & Compatibility
The README does not explicitly state a license. However, given its association with Salesforce and the nature of the code, it is likely intended for research and internal use. Compatibility with commercial or closed-source projects is not specified.
Limitations & Caveats
The library is described as "minimal" and primarily targets TPU hardware, with limited explicit support or instructions for other accelerators like GPUs (though an A100 fine-tuning example is present). The setup process, especially for TPUs, can be complex and requires familiarity with Google Cloud.
Last update: about 1 year ago; the project appears inactive.