JAX research toolkit for neural network building, editing, and visualization
Penzai is a JAX research toolkit designed for building, visualizing, and manipulating neural networks as functional pytree data structures. It targets researchers and engineers focused on model interpretability, ablation studies, and debugging, offering a declarative approach to model construction and modification.
How It Works
Penzai leverages JAX's pytree capabilities to represent neural networks as data structures, enabling programmatic inspection and modification. Key components include Treescope for enhanced pretty-printing of nested data, penzai.core.selectors for powerful pytree traversal and manipulation, and penzai.core.named_axes for a flexible named axis system that integrates seamlessly with JAX. The penzai.nn module provides a combinator-based API for defining models, supporting mutable state and parameter sharing at leaves.
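To make the "networks as pytree data" idea concrete, here is a minimal sketch of an ablation written with plain jax.tree_util only. It does not use Penzai's selector API. The two-layer parameter dict, the forward function, and the ablate helper are hypothetical names chosen for illustration, and Penzai's own selectors are meant to express this kind of rewrite more directly.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-layer model stored as a plain nested dict (a JAX pytree).
params = {
    "dense_0": {"w": jnp.ones((3, 4)), "b": jnp.zeros((4,))},
    "dense_1": {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))},
}


def forward(params, x):
    """Apply the two dense layers with a tanh nonlinearity in between."""
    h = jnp.tanh(x @ params["dense_0"]["w"] + params["dense_0"]["b"])
    return h @ params["dense_1"]["w"] + params["dense_1"]["b"]


def ablate(params, layer_name):
    """Return a new pytree with every array under `layer_name` zeroed."""

    def maybe_zero(path, leaf):
        # `path` is a tuple of key entries; dict entries expose `.key`.
        in_layer = any(getattr(entry, "key", None) == layer_name for entry in path)
        return jnp.zeros_like(leaf) if in_layer else leaf

    return jax.tree_util.tree_map_with_path(maybe_zero, params)


x = jnp.ones((1, 3))
baseline = forward(params, x)
# The original `params` is left untouched; `ablate` builds a modified copy.
ablated = forward(ablate(params, "dense_1"), x)
```

Because the edit produces a new pytree rather than mutating the model, the baseline and ablated variants can be evaluated side by side, which is the workflow the paragraph above describes.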
Quick Start & Requirements
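Install with pip:
pip install penzai
In an IPython notebook, enable Treescope as the default pretty-printer with automatic array visualization:
import treescope
treescope.basic_interactive_setup(autovisualize_arrays=True)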
Highlighted Details
penzai.core.selectors offers generalized .at[...].set() functionality for complex pytree rewrites.
penzai.core.named_axes allows seamless switching between named and positional array indexing.
Maintenance & Community
Licensing & Compatibility
Limitations & Caveats
Penzai v0.2 introduced breaking changes to the neural network API, with a V2 API focusing on mutable state and parameter sharing. Users of the V1 API can access older behavior via penzai.deprecated.v1.