mesh-transformer-jax by kingoflolz

JAX library for model-parallel transformers

Created 4 years ago
6,351 stars

Top 8.1% on SourcePulse

Project Summary

This repository provides a model-parallel implementation of transformer language models in JAX, designed for efficient scaling on TPUs. It targets researchers and practitioners who need to train and serve language models too large for a single device, with a particular focus on the GPT-J-6B model.

How It Works

The library leverages JAX's xmap and pjit operators for model parallelism, using a partitioning scheme similar to Megatron-LM. Model weights and their computations are split across devices, with communication patterns that suit the TPU's high-speed 2D mesh network. An experimental ZeRO-style sharding implementation is also included as an alternative parallelism strategy.
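
For intuition, the sketch below shows Megatron-style tensor parallelism for a transformer MLP block, written against the newer jax.sharding API rather than the library's actual xmap/pjit code; the mesh shape, axis names, and tensor sizes are illustrative assumptions (it expects 8 devices, e.g. a TPU v3-8).

```python
# Illustrative sketch only: Megatron-style tensor parallelism for an MLP block,
# expressed with jax.sharding (the library itself uses xmap/pjit).
# Assumes 8 devices, e.g. a TPU v3-8; all names and sizes here are hypothetical.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(1, 8)   # (data, model) device mesh
mesh = Mesh(devices, axis_names=("dp", "mp"))

batch, d_model, d_ff = 16, 4096, 16384
x  = jnp.ones((batch, d_model), jnp.bfloat16)
w1 = jnp.ones((d_model, d_ff), jnp.bfloat16)      # column-parallel weight
w2 = jnp.ones((d_ff, d_model), jnp.bfloat16)      # row-parallel weight

# Split the hidden dimension across the model-parallel axis and the batch
# across the data-parallel axis.
x  = jax.device_put(x,  NamedSharding(mesh, P("dp", None)))
w1 = jax.device_put(w1, NamedSharding(mesh, P(None, "mp")))
w2 = jax.device_put(w2, NamedSharding(mesh, P("mp", None)))

@jax.jit
def mlp(x, w1, w2):
    # The second matmul contracts over the sharded "mp" axis, so the XLA
    # GSPMD partitioner inserts the required all-reduce automatically.
    return jax.nn.gelu(x @ w1) @ w2

y = mlp(x, w1, w2)                                # (batch, d_model)
```

The cross-device all-reduce is inserted by the partitioner because the second contraction runs over the sharded axis; the library expresses a similar partitioning through xmap's named axes and pjit.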

Quick Start & Requirements

  • Installation: Requires pinned JAX versions for the v1 models (like GPT-J-6B): jax==0.2.12 and jaxlib==0.1.68. Newer JAX versions can be used for the v2 model code. A quick version check is sketched after this list.
  • Hardware: Primarily designed for TPUs; scripts expect to run on GCE VMs in the same region as the TPUs to minimize latency. Some device_* scripts are limited to TPU v3-8. GPU usage is possible via checkpoint resharding.
  • Resources: GPT-J-6B weights are available in slim (9GB) and full (61GB) versions. Fine-tuning on a TPU v3-8 can achieve ~5000 tokens/second.
  • Links: GPT-J-6B Weights, Colab Demo, Web Demo, Fine-tuning Guide
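
Since the v1 checkpoints are tied to the pinned JAX versions above, a small environment check before loading them can catch a mismatch early. This is a generic sketch, not part of the library's API; the expected version string is taken from the README.

```python
# Minimal sanity check before loading a v1 checkpoint such as GPT-J-6B.
# The pinned version comes from the README; the check itself is only a sketch.
import jax

EXPECTED_JAX = "0.2.12"          # v1 models; the v2 model code accepts newer JAX
if jax.__version__ != EXPECTED_JAX:
    raise RuntimeError(
        f"Found jax=={jax.__version__}; v1 checkpoints expect jax=={EXPECTED_JAX}"
    )
print("Backend:", jax.default_backend(), "| device count:", jax.device_count())
```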

Highlighted Details

  • Implements model parallelism for transformers using JAX xmap/pjit.
  • Includes GPT-J-6B, a 6-billion-parameter autoregressive model trained on The Pile.
  • Achieves efficient scaling on TPUs via a Megatron-LM-like parallelism scheme.
  • Offers an experimental ZeRO-style sharding implementation.
  • Supports fine-tuning at roughly 5,000 tokens/second on a TPU v3-8 (see the rough estimate below).
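
To put that throughput in perspective, a rough calculation, assuming the ~5,000 tokens/second figure holds steadily, gives the wall-clock time needed to stream a fixed token budget through fine-tuning:

```python
# Back-of-the-envelope only: wall-clock time to fine-tune over one billion tokens
# at the ~5,000 tokens/second rate quoted for a TPU v3-8.
tokens = 1_000_000_000
tokens_per_second = 5_000
seconds = tokens / tokens_per_second
print(f"{seconds / 86_400:.1f} days")   # ~2.3 days
```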

Maintenance & Community

The project is authored by Ben Wang. Compute resources were provided by the TPU Research Cloud with assistance from EleutherAI.

Licensing & Compatibility

The GPT-J-6B weights are licensed under the Apache License, Version 2.0. This license is generally permissive for commercial use and closed-source linking.

Limitations & Caveats

The library is optimized primarily for TPUs, and running on other hardware may require significant adaptation. Older models are tied to specific pinned JAX versions, which adds a maintenance burden. The README also notes that scaling much beyond roughly 40B parameters calls for different parallelism strategies, which marks the practical scalability ceiling of this approach.

Health Check

  • Last commit: 2 years ago
  • Responsiveness: inactive
  • Pull requests (30d): 0
  • Issues (30d): 0
  • Star history: 7 stars in the last 30 days

Explore Similar Projects

Starred by Jeff Hammerbacher (Cofounder of Cloudera), Edward Sun (Research Scientist at Meta Superintelligence Lab), and 1 more.

jaxformer by salesforce

0.7%
301
JAX library for LLM training on TPUs
Created 3 years ago
Updated 1 year ago
Starred by Jeremy Howard (Cofounder of fast.ai) and Stas Bekman (Author of "Machine Learning Engineering Open Book"; Research Engineer at Snowflake).

SwissArmyTransformer by THUDM

0.3%
1k
Transformer library for flexible model development
Created 4 years ago
Updated 8 months ago
Starred by Jiayi Pan (Author of SWE-Gym; MTS at xAI), Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), and 12 more.

EasyLM by young-geng

0.0%
2k
LLM training/finetuning framework in JAX/Flax
Created 2 years ago
Updated 1 year ago