Typing library for array shapes/dtypes
jaxtyping
provides type annotations and runtime checking for the shape and dtype of arrays from JAX, PyTorch, NumPy, MLX, and TensorFlow, as well as PyTrees. It enables developers to enforce array dimensions and data types at runtime, improving code robustness and maintainability for machine learning and scientific computing tasks.
How It Works
Array types are written as a dtype specifier (e.g., Float, Int) subscripted with an array type and a shape string, as in Float[Array, "batch height width"]. Here Float constrains the dtype, Array is the underlying array class, and "batch height width" names the three axes. When a function is checked with jaxtyping's jaxtyped decorator, a dimension name that appears in several of its annotations must have the same size everywhere within one call. The annotations are validated at runtime by a compatible type checker such as typeguard or beartype, which gives a declarative alternative to hand-written shape and dtype assertions.
Quick Start & Requirements
pip install jaxtyping
Runtime checking also requires a type checker such as typeguard or beartype.
Limitations & Caveats
The library is limited to annotating array shapes and dtypes and checking them at runtime; static type checkers see only the underlying array type, not the shape or dtype. It does not perform numerical computations itself. JAX-specific types are unavailable if JAX is not installed.