Probabilistic programming library using JAX for GPU/TPU/CPU
NumPyro is a probabilistic programming library designed for researchers and practitioners who need efficient Bayesian inference. It leverages JAX for automatic differentiation and Just-In-Time (JIT) compilation, enabling accelerated computation on GPUs and TPUs. The library offers a familiar API for users of Pyro and PyTorch, simplifying the transition for those familiar with these ecosystems.
How It Works
NumPyro builds upon JAX's functional programming paradigm and its ability to compile Python and NumPy code into optimized kernels. This allows for significant speedups in complex probabilistic models, particularly for Hamiltonian Monte Carlo (HMC) and its variants like the No-U-Turn Sampler (NUTS). By composing JAX's jit and grad functionalities, NumPyro can compile entire inference steps, including gradient computations and tree-building, into highly efficient XLA-optimized code. It also provides a comprehensive suite of distributions and effect handlers for flexible model specification and custom inference algorithm development.
Quick Start & Requirements
pip install numpyro
pip install 'numpyro[cuda]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
conda install -c conda-forge numpyro
Highlighted Details
Reparameterization handlers (e.g., TransformReparam, LocScaleReparam) to improve MCMC convergence.
Limitations & Caveats
NumPyro is under active development, which may lead to API brittleness and bugs. Windows support is limited and may require building JAXlib from source or using the Windows Subsystem for Linux. Models may need to be rewritten in a more functional style to fully leverage JAX's JIT compilation.
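The caveat about functional style can be illustrated in plain JAX. The example is generic and not NumPyro-specific, and the function names are hypothetical. A Python `if` on a traced value fails under `jax.jit`, while an equivalent `jnp.where` expression compiles and also works elementwise on arrays.

```python
import jax
import jax.numpy as jnp


def soft_threshold_python(x, t):
    # Imperative style: a Python `if` needs a concrete boolean, which a
    # traced value cannot provide inside jax.jit.
    if x > t:
        return x - t
    return 0.0


@jax.jit
def soft_threshold(x, t):
    # Functional style: jnp.where selects elementwise between the two
    # branches, so the function traces and compiles cleanly.
    return jnp.where(x > t, x - t, 0.0)


try:
    jax.jit(soft_threshold_python)(2.0, 1.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    python_if_failed = True

print(python_if_failed)  # True
print(soft_threshold(jnp.array([0.5, 2.0, 3.0]), 1.0))  # [0. 1. 2.]
```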