mesh-transformer-jax by kingoflolz

JAX library for model-parallel transformers

created 4 years ago
6,349 stars

Top 8.2% on sourcepulse

Project Summary

This repository provides a model-parallel implementation of transformer language models in JAX, designed for efficient scaling on TPUs. It targets researchers and practitioners who need to train and deploy language models too large for a single device, with a focus on the GPT-J-6B model.

How It Works

The library leverages JAX's xmap and pjit operators for model parallelism, using a partitioning scheme similar to Megatron-LM's. Model weights and computations are split across multiple devices, with communication patterns chosen to suit the TPU's 2D mesh network. An experimental ZeRO-style sharding implementation is also included as an alternative parallelism strategy.
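
As an illustration of the Megatron-style tensor parallelism described above (not the repository's own code), the sketch below shards one feed-forward layer across a 1D device mesh with JAX's NamedSharding. It assumes a recent JAX release (0.4+, where pjit's partitioning behaviour is folded into jax.jit) rather than the pinned version needed for the v1 checkpoints, uses a 1D mesh instead of the library's 2D TPU mesh, and all names (ffn, the "mp" axis, the layer sizes) are illustrative.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # 1D model-parallel mesh over all local devices (the library itself maps
    # onto the TPU's 2D mesh; one axis keeps this sketch short).
    mesh = Mesh(np.array(jax.devices()), axis_names=("mp",))

    d_model, d_ff, batch = 512, 2048, 8
    # Column-parallel first projection, row-parallel second projection,
    # as in the Megatron-LM scheme.
    w_in = jax.device_put(jnp.ones((d_model, d_ff)), NamedSharding(mesh, P(None, "mp")))
    w_out = jax.device_put(jnp.ones((d_ff, d_model)), NamedSharding(mesh, P("mp", None)))

    @jax.jit  # in current JAX, pjit's functionality lives in jax.jit
    def ffn(x, w_in, w_out):
        # Each device holds a slice of the hidden dimension; the second matmul
        # leads the compiler to insert an all-reduce over the "mp" axis.
        return jax.nn.gelu(x @ w_in) @ w_out

    y = ffn(jnp.ones((batch, d_model)), w_in, w_out)
    print(y.shape)  # (8, 512)

Keeping the first projection column-sharded and the second row-sharded means only one all-reduce is needed per feed-forward block, which is what makes this style of partitioning communication-efficient on a TPU mesh.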

Quick Start & Requirements

  • Installation: Requires pinned JAX versions for the v1 models (including GPT-J-6B): jax==0.2.12 and jaxlib==0.1.68. Newer JAX releases can be used with the v2 model code. A quick environment sanity check is sketched after this list.
  • Hardware: Primarily designed for TPUs; the scripts expect to run on GCE VMs in the same region as the TPUs to minimize latency. Some of the device_* scripts run only on a TPU v3-8. GPU use is possible after resharding the checkpoint.
  • Resources: GPT-J-6B weights are available in slim (9GB) and full (61GB) versions. Fine-tuning on a TPU v3-8 reaches roughly 5,000 tokens/second.
  • Links: GPT-J-6B Weights, Colab Demo, Web Demo, Fine-tuning Guide
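
As a quick sanity check of the pinned environment, something like the following confirms which accelerator JAX sees. Only standard JAX calls are used; nothing here is specific to this repository, and the expected outputs in the comments assume a TPU VM.

    # After e.g.: pip install jax==0.2.12 jaxlib==0.1.68  (for the v1 models)
    import jax

    print("jax version:", jax.__version__)
    print("platform:", jax.devices()[0].platform)  # expect "tpu" on a TPU VM
    print("device count:", jax.device_count())     # 8 on a TPU v3-8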

Highlighted Details

  • Implements model parallelism for transformers using JAX xmap/pjit.
  • Includes GPT-J-6B, a 6 billion parameter autoregressive model trained on The Pile.
  • Achieves efficient scaling on TPUs via a Megatron-LM-like parallelism scheme.
  • Offers an experimental ZeRO-style sharding implementation (a general sketch of the idea follows this list).
  • Supports fine-tuning with high token throughput on TPUs.
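
The ZeRO-style path is only mentioned in passing above. As a general illustration of the idea, and not this repository's implementation, the sketch below keeps a parameter and the batch sharded across data-parallel workers and lets JAX's compiler insert the collectives. It assumes a recent JAX release, and all names are illustrative.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

    # ZeRO/FSDP-style: both the parameter and the batch are sharded across the
    # data-parallel axis, so no single device holds a full copy of either.
    w = jax.device_put(jnp.ones((4096, 4096)), NamedSharding(mesh, P("dp", None)))
    x = jax.device_put(jnp.ones((64, 4096)), NamedSharding(mesh, P("dp", None)))

    def loss(w, x):
        return jnp.mean((x @ w) ** 2)

    # The compiler chooses the collective schedule (typically gathering w for
    # the matmul and reducing its gradient back toward a sharded layout).
    grads = jax.jit(jax.grad(loss))(w, x)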

Maintenance & Community

The project is authored by Ben Wang. Compute resources were provided by the TPU Research Cloud with assistance from EleutherAI.

Licensing & Compatibility

The GPT-J-6B weights are licensed under the Apache License, Version 2.0, a permissive license that generally allows commercial use and inclusion in closed-source projects.

Limitations & Caveats

The library is primarily optimized for TPUs; running on other hardware may require significant adaptation. Older models depend on pinned JAX versions, which adds a maintenance burden. The README notes that models beyond roughly 40B parameters may require different parallelism approaches, which suggests a practical scalability ceiling for this library.

Health Check

  • Last commit: 2 years ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 30 stars in the last 90 days

Explore Similar Projects

Starred by Chip Huyen (author of AI Engineering and Designing Machine Learning Systems), Jiayi Pan (author of SWE-Gym; AI researcher at UC Berkeley), and 11 more.

alpa by alpa-projects: Auto-parallelization framework for large-scale neural network training and serving. ~3k stars (top 0.1%); created 4 years ago, updated 1 year ago.