JAX library for model-parallel transformers
This repository provides a model-parallel implementation of transformer language models in JAX, designed for efficient scaling on TPUs. It targets researchers and practitioners who need to train and deploy language models too large for a single device, with a focus on the GPT-J-6B model.
How It Works
The library leverages JAX's xmap and pjit operators for model parallelism, employing a scheme similar to Megatron-LM. This approach partitions model weights and computations across multiple devices, optimizing communication for the TPU's 2D mesh network. An experimental ZeRO-style sharding mode is also included as an alternative parallelism strategy.
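To make the partitioning scheme concrete, here is a minimal sketch of Megatron-style column-parallel sharding of a single linear layer. It uses the newer jax.sharding API rather than the xmap/pjit code in this library (which targets the pinned jax==0.2.12), and all array shapes and names are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D device mesh with a single "mp" (model-parallel) axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("mp",))

# Toy linear layer: split the weight column-wise across the mesh and
# replicate the activations, in the spirit of Megatron-LM partitioning.
W = jnp.ones((1024, 4096))
x = jnp.ones((8, 1024))

W_sharded = jax.device_put(W, NamedSharding(mesh, P(None, "mp")))
x_replicated = jax.device_put(x, NamedSharding(mesh, P()))

@jax.jit
def forward(x, W):
    # Each device multiplies against its local weight shard; XLA inserts
    # any collectives needed to keep the output consistently sharded.
    return x @ W

y = forward(x_replicated, W_sharded)
print(y.shape, y.sharding)  # (8, 4096), sharded over "mp" along the last axis
```

On a single-host setup with one device the same code still runs, with the mesh collapsing to a trivial one-device partition.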
Quick Start & Requirements
Requires jax==0.2.12 and jaxlib==0.1.68 for v1 models (like GPT-J-6B); newer JAX versions can be used with the v2 model code. The device_* scripts are limited to TPU v3-8. GPU usage is possible via checkpoint resharding.
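As a quick sanity check of the environment (not part of the library itself), the installed JAX version and visible device count can be verified like this:

```python
import jax

print(jax.__version__)     # v1 checkpoints such as GPT-J-6B expect 0.2.12
print(jax.device_count())  # a TPU v3-8 should report 8 devices
```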
Highlighted Details
Model parallelism built on JAX's xmap/pjit operators.
Maintenance & Community
The project is authored by Ben Wang. Compute resources were provided by the TPU Research Cloud with assistance from EleutherAI.
Licensing & Compatibility
The GPT-J-6B weights are licensed under the Apache License, Version 2.0. This license is generally permissive for commercial use and closed-source linking.
Limitations & Caveats
The library is primarily optimized for TPUs, and running on other hardware may require significant adaptation. Specific JAX versions are required for compatibility with older models, which can be a maintenance burden. The README notes that scaling beyond roughly 40B parameters calls for different parallelism strategies, which suggests a practical scalability ceiling for this library.