Jax-based ML framework for large-scale model training and experimentation
Paxml (Pax) is a Jax-based framework for configuring and running large-scale machine learning experiments, targeting researchers and engineers who need advanced parallelization and high FLOP utilization. It provides a flexible system for defining models, tasks, and data pipelines, enabling efficient training of massive models on distributed hardware.
How It Works
Pax leverages Jax for automatic differentiation and XLA compilation, enabling high performance on accelerators such as TPUs. It uses a Pythonic dataclass-based configuration system (via Fiddle) for hyperparameters, allowing nested configurations and shared layer instances. Models are built with a layer-based architecture, inheriting from Flax's nn.Module, with a clear separation of concerns between weight creation (setup) and forward propagation (fprop). Data handling is designed for multihost environments, ensuring proper sharding and avoiding duplicate batches.
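To make the setup/forward split concrete, here is a minimal sketch in plain flax.linen rather than the Paxml/Praxis layer API; the layer name, shapes, and initializers are illustrative assumptions, not Paxml code:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class Linear(nn.Module):
    """Toy layer showing the weight-creation / forward-pass split described above."""
    in_features: int
    out_features: int

    def setup(self):
        # Weights are created only in setup(), separate from the forward computation.
        self.kernel = self.param(
            "kernel", nn.initializers.lecun_normal(),
            (self.in_features, self.out_features))
        self.bias = self.param("bias", nn.initializers.zeros, (self.out_features,))

    def __call__(self, x):
        # Forward propagation (the role fprop plays in Pax's layer API).
        return x @ self.kernel + self.bias


layer = Linear(in_features=4, out_features=8)
x = jnp.ones((2, 4))
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # y.shape == (2, 8)
```

In Pax, the same division of responsibilities applies, with hyperparameters supplied through Fiddle-based configurations rather than plain constructor arguments.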
Quick Start & Requirements
Install the stable release with:
pip install -U pip && pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
or install the development version from a cloned checkout with:
pip install -e praxis && pip install -e paxml
Running on Google Cloud TPUs also requires the gcloud CLI.
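Once installed, experiments are launched through paxml's main module; the invocation below is a hedged sketch based on the upstream README, with the experiment class and log directory serving as placeholders:
python3 paxml/main.py --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=gs://<your-bucket>/<your-dir>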
Highlighted Details
Maintenance & Community
546370f5323ef8b27d38ddc32445d7d3d1e4da9a
Licensing & Compatibility
Limitations & Caveats
The primary focus and setup instructions are heavily geared towards Google Cloud TPUs. While NVIDIA has released a GPU-optimized version, the core Paxml repository's GPU support and performance characteristics may differ. The README points to release-specific requirements.txt files for stable releases, suggesting that exact reproducibility depends on pinned dependencies.