paxml by Google

JAX-based ML framework for large-scale model training and experimentation

created 3 years ago
518 stars

Top 61.5% on sourcepulse

Project Summary

Paxml (Pax) is a JAX-based framework for configuring and running large-scale machine learning experiments, targeting researchers and engineers who need advanced parallelization and high FLOP utilization. It provides a flexible system for defining models, tasks, and data pipelines, enabling efficient training of massive models on distributed hardware.

How It Works

Pax leverages JAX for automatic differentiation and XLA compilation, enabling high performance on accelerators such as TPUs. It uses a Pythonic, dataclass-based configuration system (via Fiddle) for hyperparameters, allowing nested configurations and shared layer instances. Models are built from layers that inherit from Flax's nn.Module, with a clear separation between weight creation (setup) and forward propagation (fprop). Data handling is designed for multihost environments, ensuring proper sharding and avoiding duplicate batches.
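As a minimal sketch of the nested, dataclass-style configuration pattern described above — using plain Python dataclasses as a hypothetical stand-in, not Pax's actual Fiddle/HParams API — overrides on sub-configs compose naturally:

```python
from dataclasses import dataclass, field

# Hypothetical, simplified stand-in for Pax's nested dataclass-style
# configuration; real Pax experiments use Fiddle-based configs, and these
# class and field names are illustrative only.

@dataclass
class LinearHParams:
    input_dims: int = 0
    output_dims: int = 0

@dataclass
class ModelHParams:
    # Nested configuration: sub-layer configs are themselves dataclasses.
    hidden: LinearHParams = field(default_factory=lambda: LinearHParams(128, 512))
    output: LinearHParams = field(default_factory=lambda: LinearHParams(512, 10))

cfg = ModelHParams()
cfg.hidden.output_dims = 1024               # override a nested hyperparameter
cfg.output.input_dims = cfg.hidden.output_dims  # keep layer shapes consistent
```

In real Pax experiments, an experiment class builds such a config tree and the trainer instantiates layers from it, so a single override propagates to every layer that shares the sub-config.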

Quick Start & Requirements

  • Installation: pip install -U pip && pip install paxml 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (stable release; the extras spec is quoted so the shell does not treat the brackets as a glob) or pip install -e praxis && pip install -e paxml (development version).
  • Prerequisites: Google Cloud TPU VM, gcloud CLI.
  • Documentation: Paxml Docs and Jupyter Notebook tutorials.
  • Example Runs: Provided for various model sizes (1B, 16B, GPT3-XL) on Cloud TPU v4.

Highlighted Details

  • Demonstrated industry-leading Model FLOPs Utilization (MFU) rates on TPU v4.
  • Supports advanced parallelization strategies including SPMD and multislice configurations.
  • Interoperability with Fiddle for configuration management and eager error checking.
  • Flexible data input system supporting SeqIO, Lingvo, and custom data pipelines.
  • Includes detailed mappings for migrating from MaxText configurations.
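As a rough illustration of what MFU measures (not Pax's internal accounting), the common approximation counts about 6 FLOPs per parameter per trained token for a dense transformer (forward plus backward pass) and divides observed throughput by the accelerator's peak. The figures below are hypothetical:

```python
def model_flops_utilization(n_params, tokens_per_sec, peak_flops_per_sec):
    """Approximate MFU: observed training FLOPs/s over hardware peak.

    Uses the common ~6 * N FLOPs-per-token estimate for dense
    transformer training (forward + backward pass).
    """
    achieved = 6.0 * n_params * tokens_per_sec
    return achieved / peak_flops_per_sec

# Example: a hypothetical 1B-parameter model processing 20k tokens/s
# on one TPU v4 chip (~275 TFLOP/s peak bf16).
mfu = model_flops_utilization(1e9, 2.0e4, 275e12)
print(f"{mfu:.1%}")  # roughly 44%
```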

Maintenance & Community

  • Developed by Google.
  • Version 0.1.0 released with commit 546370f5323ef8b27d38ddc32445d7d3d1e4da9a.
  • NVIDIA has a separate GPU-optimized version available via the Rosetta repository.

Licensing & Compatibility

  • Licensed under the Apache License, Version 2.0.
  • Permissive license suitable for commercial use and integration with closed-source projects.

Limitations & Caveats

The primary focus and setup instructions are heavily geared towards Google Cloud TPUs. While NVIDIA has released a GPU-optimized version, the core Paxml repository's GPU support and performance characteristics may differ. The README points to pinned requirements.txt files for stable releases, so exact reproducibility may require careful dependency management.

Health Check

  • Last commit: 1 day ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 31 stars in the last 90 days
