numpyro by pyro-ppl

Probabilistic programming library using JAX for GPU/TPU/CPU

Created 6 years ago
2,529 stars

Top 18.5% on SourcePulse

Project Summary

NumPyro is a probabilistic programming library designed for researchers and practitioners who need efficient Bayesian inference. It leverages JAX for automatic differentiation and Just-In-Time (JIT) compilation, enabling accelerated computation on GPUs and TPUs. Its API will be familiar to users of Pyro and PyTorch, which eases the transition from those ecosystems.
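The two JAX features named here, automatic differentiation and JIT compilation, compose directly. The following is a minimal plain-JAX sketch, not NumPyro's internal code. The toy log density (a Normal(0, 10) prior on mu and a unit-variance Normal likelihood) and the function names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def log_density(mu, data):
    # Unnormalized log density: Normal(0, 10) prior on mu plus
    # an independent Normal(mu, 1) likelihood for each observation.
    log_prior = -0.5 * (mu / 10.0) ** 2
    log_lik = -0.5 * jnp.sum((data - mu) ** 2)
    return log_prior + log_lik


# jax.grad differentiates w.r.t. the first argument (mu);
# jax.jit compiles the resulting gradient function with XLA.
grad_log_density = jax.jit(jax.grad(log_density))

data = jnp.array([1.0, 2.0, 3.0])
# At mu = 0 the prior term contributes 0, so the gradient is sum(data) = 6.
g = grad_log_density(0.0, data)
```

Gradient-based samplers such as HMC evaluate a function like grad_log_density at every leapfrog step, so compiling it once and reusing it is where much of the speedup comes from.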

How It Works

NumPyro builds on JAX's functional programming model and its ability to compile Python and NumPy code into optimized kernels. This yields significant speedups for complex probabilistic models, particularly with Hamiltonian Monte Carlo (HMC) and its adaptive variant, the No-U-Turn Sampler (NUTS). By composing JAX's jit and grad transformations, NumPyro can compile an entire inference step, including gradient computations and NUTS tree-building, into a single XLA-optimized program. It also provides a comprehensive suite of distributions and effect handlers for flexible model specification and for developing custom inference algorithms.

Quick Start & Requirements

  • Installation:
    • CPU: pip install numpyro
    • GPU: pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    • Conda: conda install -c conda-forge numpyro
  • Prerequisites: JAX, NumPy. GPU support requires CUDA.
  • Resources: JAX's JIT compilation reduces Python interpreter overhead, so CPU runs benefit as well. GPU/TPU acceleration is available once a matching accelerator-enabled jaxlib is installed.
  • Docs: https://numpyro.readthedocs.io/en/latest/
  • Examples: https://github.com/pyro-ppl/numpyro/tree/main/examples

Highlighted Details

  • JAX Backend: Leverages JAX for automatic differentiation, JIT compilation, and hardware acceleration (GPU/TPU).
  • MCMC Algorithms: Supports NUTS, HMC, MixedHMC, HMCECS, BarkerMH, HMCGibbs, DiscreteHMCGibbs, and SA.
  • Variational Inference: Includes stochastic variational inference (SVI) with various automatic guides (AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, AutoLaplaceApproximation, etc.), covering ADVI-style approximations, and supports normalizing flows.
  • Reparameterization: Offers built-in reparameterization strategies (e.g., TransformReparam, LocScaleReparam) to improve MCMC convergence.
  • Enumeration: Supports parallel enumeration for models with discrete latent variables.

Maintenance & Community

  • Contributors: Active development with contributions from researchers in the probabilistic programming community.
  • Community: The Pyro forum (https://forum.pyro.ai/) is available for discussion and support.
  • Roadmap: Focus on improving inference robustness, performance tuning, and expanding inference algorithm support.

Licensing & Compatibility

  • License: Apache License 2.0.
  • Compatibility: Permissive license allows for commercial use and integration with closed-source projects.

Limitations & Caveats

NumPyro is under active development, so APIs may change and bugs may surface. Windows support is limited and may require building jaxlib from source or using the Windows Subsystem for Linux. Models may need to be rewritten in a more functional style to take full advantage of JAX's JIT compilation.
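As a small illustration of what "a more functional style" means in practice, here is a plain-JAX sketch; the function names are hypothetical. Python if statements on traced values fail under jax.jit, while jnp.where expresses the same branch inside the compiled graph.

```python
import jax
import jax.numpy as jnp


def clip_negative_python(x):
    # Python control flow on a traced value: under jax.jit, x is an
    # abstract tracer, so its truth value is unknown at trace time.
    if x < 0:
        return 0.0
    return x


def clip_negative_functional(x):
    # Functional rewrite: jnp.where selects elementwise inside the
    # compiled graph, so it traces and compiles cleanly.
    return jnp.where(x < 0, 0.0, x)


clip_jit = jax.jit(clip_negative_functional)
out = clip_jit(jnp.array([-2.0, 0.5, 3.0]))
print(out)

try:
    jax.jit(clip_negative_python)(-2.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # Tracer-to-bool conversion errors subclass ConcretizationTypeError.
    python_if_failed = True
```

For loops that contain sample statements, NumPyro also provides numpyro.contrib.control_flow.scan, which plays the role of jax.lax.scan inside models.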

Health Check

  • Last Commit: 1 week ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 6
  • Issues (30d): 5
  • Star History: 22 stars in the last 30 days
