JAX module for symbolic computation
This library converts SymPy expressions into trainable JAX modules, enabling gradient-based optimization of symbolic mathematical formulas. It is designed for researchers and engineers working with symbolic computation who want to leverage JAX's automatic differentiation and hardware acceleration capabilities for parameter tuning.
How It Works
The library transforms a PyTree of SymPy expressions into an Equinox module. SymPy symbols map to module inputs, while SymPy numbers (integers, rationals, floats) are treated as trainable parameters (leaves in the PyTree). Users can therefore define complex symbolic functions and then optimize their numerical coefficients with standard JAX/Equinox workflows.
Quick Start & Requirements
pip install sympy2jax
Limitations & Caveats
The library describes itself as "super easy" and its documentation is minimal, so it may be best suited to straightforward use cases. This summary found no explicit licensing information; check the repository for a license before adopting the library, as this is a significant caveat.