numpyro by pyro-ppl

Probabilistic programming library using JAX for GPU/TPU/CPU

created 6 years ago
2,502 stars

Top 19.1% on sourcepulse

Project Summary

NumPyro is a probabilistic programming library designed for researchers and practitioners who need efficient Bayesian inference. It leverages JAX for automatic differentiation and just-in-time (JIT) compilation, enabling accelerated computation on GPUs and TPUs. The library offers an API familiar to users of Pyro and PyTorch, easing the transition from those ecosystems.

How It Works

NumPyro builds upon JAX's functional programming paradigm and its ability to compile Python and NumPy code into optimized kernels. This allows for significant speedups in complex probabilistic models, particularly for Hamiltonian Monte Carlo (HMC) and its variants like the No-U-Turn Sampler (NUTS). By composing JAX's jit and grad functionalities, NumPyro can compile entire inference steps, including gradient computations and tree-building, into highly efficient XLA-optimized code. It also provides a comprehensive suite of distributions and effect handlers for flexible model specification and custom inference algorithm development.

Quick Start & Requirements

  • Installation:
    • CPU: pip install numpyro
    • GPU: pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    • Conda: conda install -c conda-forge numpyro
  • Prerequisites: JAX, NumPy. GPU support requires CUDA.
  • Resources: JAX's JIT compilation can significantly reduce CPU overhead. GPU/TPU acceleration is a key feature.
  • Docs: https://numpyro.readthedocs.io/en/latest/
  • Examples: https://github.com/pyro-ppl/numpyro/tree/main/examples

Highlighted Details

  • JAX Backend: Leverages JAX for automatic differentiation, JIT compilation, and hardware acceleration (GPU/TPU).
  • MCMC Algorithms: Supports NUTS, HMC, MixedHMC, HMCECS, BarkerMH, HMCGibbs, DiscreteHMCGibbs, and SA.
  • Variational Inference: Includes ADVI with various automatic guides (AutoNormal, AutoDiagonalNormal, AutoMultivariateNormal, AutoLaplaceApproximation, etc.) and support for normalizing flows.
  • Reparameterization: Offers built-in reparameterization strategies (e.g., TransformReparam, LocScaleReparam) to improve MCMC convergence.
  • Enumeration: Supports parallel enumeration for models with discrete latent variables.

Maintenance & Community

  • Contributors: Active development with contributions from researchers in the probabilistic programming community.
  • Community: Forum available for discussions and support.
  • Roadmap: Focus on improving inference robustness, performance tuning, and expanding inference algorithm support.

Licensing & Compatibility

  • License: Apache License 2.0.
  • Compatibility: Permissive license allows for commercial use and integration with closed-source projects.

Limitations & Caveats

NumPyro is under active development, which may lead to API brittleness and bugs. Windows support is limited and may require building JAXlib from source or using the Windows Subsystem for Linux. Models may need to be rewritten in a more functional style to fully leverage JAX's JIT compilation.

Health Check

  • Last commit: 2 weeks ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 9
  • Issues (30d): 9
  • Star History: 79 stars in the last 90 days

Explore Similar Projects

Starred by Andrej Karpathy (Founder of Eureka Labs; formerly at Tesla, OpenAI; author of CS 231n), Jeff Hammerbacher (Cofounder of Cloudera), and 2 more.

gemma_pytorch by google

PyTorch implementation for Google's Gemma models
6k stars
created 1 year ago, updated 2 months ago