JAX toolbox provides CI, Docker images, and examples for JAX development on NVIDIA GPUs
NVIDIA/JAX-Toolbox provides pre-built Docker images and optimized examples for popular JAX libraries, targeting researchers and engineers working with large language models and other deep learning tasks on NVIDIA GPUs. It simplifies setup and enhances performance by bundling optimized JAX frameworks and providing curated XLA/NCCL flags.
How It Works
The toolbox offers a collection of Docker images, each tailored for specific JAX frameworks like MaxText, T5X, and Levanter, supporting various model architectures. These images are pre-configured with performance-enhancing XLA flags and NCCL settings, aiming to maximize GPU utilization and reduce communication overhead. NVIDIA maintains a public CI pipeline for these containers and conducts internal testing on high-end NVIDIA hardware.
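As a rough illustration of what "pre-configured XLA flags and NCCL settings" means in practice, the sketch below merges extra flags into the XLA_FLAGS environment variable and sets an NCCL variable before JAX is imported. The helper name merge_xla_flags is hypothetical, and the flag values are illustrative only; they are not taken from the toolbox and may differ from what its images actually set.

```python
import os


def merge_xla_flags(existing: str, extra: dict[str, str]) -> str:
    """Merge extra flags into an XLA_FLAGS string.

    Hypothetical helper: entries in `extra` override flags of the same
    name already present in `existing`; everything else is preserved.
    """
    merged: dict[str, str] = {}
    for token in existing.split():
        name, _, value = token.partition("=")
        merged[name] = value
    merged.update(extra)
    # Flags without a value (e.g. "--some_switch") are emitted bare.
    return " ".join(f"{k}={v}" if v else k for k, v in merged.items())


# Illustrative values only (an assumption, not the toolbox's curated set).
extra_flags = {"--xla_gpu_enable_latency_hiding_scheduler": "true"}

os.environ["XLA_FLAGS"] = merge_xla_flags(os.environ.get("XLA_FLAGS", ""), extra_flags)

# NCCL also reads its configuration from the environment.
os.environ.setdefault("NCCL_NVLS_ENABLE", "0")

# import jax  # import JAX only after the environment has been prepared
```

Setting these variables in the process environment before importing JAX is the simplest way to make sure they are picked up; inside the toolbox images the equivalent settings are expected to be in place already.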
Quick Start & Requirements
Pull an image with docker pull ghcr.io/nvidia/jax:XXX.
Start the container with a larger shared-memory size, for example docker run -it --shm-size=1g ..., to address potential bus error issues.
Highlighted Details
Maintenance & Community
The project is maintained by NVIDIA. Links to community resources like Discord/Slack are not explicitly provided in the README.
Licensing & Compatibility
The README does not explicitly state a license. Compatibility for commercial use or closed-source linking is not specified.
Limitations & Caveats
The NCCL_NVLS_ENABLE environment variable is currently set to 0 (disabling NVLink SHARP), with plans to re-enable it in future releases. Users encountering issues with multi-arch images in enroot/pyxis may need to upgrade enroot to v3.4.0 or later.
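Both caveats can be checked with a few lines of Python. The following is a minimal sketch, assuming the enroot version is available as a plain dotted string such as "3.4.0" (optionally prefixed with "v"); the function names are hypothetical and not part of the toolbox.

```python
import os
from collections.abc import Mapping

MIN_ENROOT = (3, 4, 0)  # minimum version named in the caveat above


def parse_version(text: str) -> tuple[int, ...]:
    """Parse a dotted numeric version; assumes no suffixes such as '-rc1'."""
    parts = tuple(int(p) for p in text.strip().lstrip("v").split("."))
    # Pad short versions ("3.5" -> (3, 5, 0)) so comparisons are consistent.
    return parts + (0,) * (3 - len(parts))


def enroot_needs_upgrade(version: str) -> bool:
    """True if the given enroot version is older than v3.4.0."""
    return parse_version(version) < MIN_ENROOT


def nvls_explicitly_disabled(env: Mapping[str, str] = os.environ) -> bool:
    """True if NCCL_NVLS_ENABLE is set to 0 in the given environment."""
    return env.get("NCCL_NVLS_ENABLE") == "0"


if __name__ == "__main__":
    print("NVLS explicitly disabled:", nvls_explicitly_disabled())
    print("enroot 3.3.1 needs upgrade:", enroot_needs_upgrade("3.3.1"))
```

The version string would typically come from the enroot installation on the cluster; how it is obtained is left out here on purpose.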