paxml by google

Jax-based ML framework for large-scale model training and experimentation

Created 3 years ago
536 stars

Top 59.4% on SourcePulse

Project Summary

Paxml (Pax) is a Jax-based framework for configuring and running large-scale machine learning experiments, targeting researchers and engineers who need advanced parallelization and high FLOPs utilization. It provides a flexible system for defining models, tasks, and data pipelines, enabling efficient training of massive models on distributed hardware.

How It Works

Pax leverages Jax for its automatic differentiation and XLA compilation, enabling high performance on accelerators like TPUs. It uses a Pythonic dataclass-based configuration system (via Fiddle) for hyperparameters, allowing for nested configurations and shared layer instances. Models are built using a layer-based architecture, inheriting from Flax's nn.Module, with clear separation of concerns for weight creation (setup) and forward propagation (fprop). Data handling is designed for multihost environments, ensuring proper sharding and avoiding duplicate batches.
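As a rough illustration of these ideas in plain Flax and Fiddle (not Pax's own base classes; the FeedForward layer and its features field are hypothetical), weight creation in setup() stays separate from the forward pass, and hyperparameters live in a buildable, dataclass-style config:

```python
import jax
import jax.numpy as jnp
import fiddle as fdl
from flax import linen as nn

class FeedForward(nn.Module):
    """Hypothetical layer: weights are created in setup(), not in the forward pass."""
    features: int

    def setup(self):
        # Weight creation lives here (Pax keeps this separate from fprop).
        self.dense = nn.Dense(self.features)

    def __call__(self, x):
        # Forward propagation: only consumes the weights created in setup().
        return jax.nn.relu(self.dense(x))

# Dataclass-style Fiddle config: hyperparameters are plain, nestable attributes.
cfg = fdl.Config(FeedForward, features=16)
cfg.features = 32                    # override before building
model = fdl.build(cfg)               # instantiates FeedForward(features=32)

params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
out = model.apply(params, jnp.ones((2, 8)))  # shape (2, 32)
```

Because configs are ordinary Python objects, they can be nested, and a single sub-config can be shared between layers before building, which is how shared layer instances are expressed.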

Quick Start & Requirements

  • Installation: pip install -U pip && pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (stable release), or clone the praxis and paxml repositories and run pip install -e praxis && pip install -e paxml (dev version).
  • Prerequisites: Google Cloud TPU VM, gcloud CLI.
  • Documentation: Paxml Docs and Jupyter Notebook tutorials.
  • Example Runs: Provided for various model sizes (1B, 16B, GPT3-XL) on Cloud TPU v4; see the experiment sketch after this list.
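The example runs are selected with an --exp flag naming a registered experiment class. A rough sketch of that pattern, assuming paxml's experiment_registry and BaseExperiment APIs (the class name, method bodies, and launch paths below are illustrative, not taken from the repository):

```python
from paxml import base_experiment
from paxml import experiment_registry


@experiment_registry.register
class MyLmExperiment(base_experiment.BaseExperiment):
  """Hypothetical experiment, selected on the command line via --exp."""

  def task(self):
    # Return the config describing the model, optimizer, and training task.
    raise NotImplementedError  # sketch only

  def datasets(self):
    # Return configs for the train/eval input pipelines (e.g. SeqIO-backed).
    raise NotImplementedError  # sketch only


# Launch pattern from the example runs (paths and bucket are placeholders):
#   python3 paxml/main.py --exp=my_module.MyLmExperiment \
#       --job_log_dir=gs://your-bucket/logdir
```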

Highlighted Details

  • Demonstrated industry-leading Model FLOPs Utilization (MFU) rates on TPU v4.
  • Supports advanced parallelization strategies including SPMD and multislice configurations; see the JAX sharding sketch after this list.
  • Interoperability with Fiddle for configuration management and eager error checking.
  • Flexible data input system supporting SeqIO, Lingvo, and custom data pipelines.
  • Includes detailed mappings for migrating from MaxText configurations.
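For context on the SPMD bullet above, here is a plain-JAX sketch of the primitives Pax builds on (not Pax's own API; the mesh axis name and array shapes are illustrative). Sharded inputs let XLA partition a jitted computation across devices:

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh with a named "data" axis.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

# Shard the batch dimension across "data"; replicate the feature dimension.
x = jax.device_put(jnp.ones((8, 128)), NamedSharding(mesh, P("data", None)))

@jax.jit
def step(x):
    # XLA partitions this computation to match the sharding of x.
    return x * 2.0

y = step(x)
```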

Maintenance & Community

  • Developed by Google.
  • Version 0.1.0 released with commit 546370f5323ef8b27d38ddc32445d7d3d1e4da9a.
  • NVIDIA has a separate GPU-optimized version available via the Rosetta repository.

Licensing & Compatibility

  • Licensed under the Apache License, Version 2.0.
  • Permissive license suitable for commercial use and integration with closed-source projects.

Limitations & Caveats

The primary focus and setup instructions are heavily geared toward Google Cloud TPUs. While NVIDIA has released a GPU-optimized version, the core Paxml repository's GPU support and performance characteristics may differ. The README pins stable releases to specific requirements.txt files, so reproducing an exact environment may require careful dependency management.

Health Check

  • Last Commit: 2 weeks ago
  • Responsiveness: 1 week
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 10 stars in the last 30 days

Explore Similar Projects

Starred by Edward Sun (Research Scientist at Meta Superintelligence Lab), Phil Wang (Prolific Research Paper Implementer), and 1 more.

grain by google

0.9%
536 stars
Python library for ML training data pipelines
Created 3 years ago
Updated 1 day ago
Starred by Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), Jiayi Pan (Author of SWE-Gym; MTS at xAI), and 20 more.

alpa by alpa-projects

0.0%
3k stars
Auto-parallelization framework for large-scale neural network training and serving
Created 4 years ago
Updated 1 year ago