JAX-Toolbox by NVIDIA

JAX-Toolbox provides CI, Docker images, and examples for JAX development on NVIDIA GPUs.

Created 2 years ago
335 stars

Top 82.0% on SourcePulse

Project Summary

NVIDIA/JAX-Toolbox provides pre-built Docker images and optimized examples for popular JAX libraries, targeting researchers and engineers working with large language models and other deep learning tasks on NVIDIA GPUs. It simplifies setup and enhances performance by bundling optimized JAX frameworks and providing curated XLA/NCCL flags.

How It Works

The toolbox offers a collection of Docker images, each tailored for specific JAX frameworks like MaxText, T5X, and Levanter, supporting various model architectures. These images are pre-configured with performance-enhancing XLA flags and NCCL settings, aiming to maximize GPU utilization and reduce communication overhead. NVIDIA maintains a public CI pipeline for these containers and conducts internal testing on high-end NVIDIA hardware.
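
As a rough illustration of how XLA flags like these reach a JAX process, the sketch below merges a list of flags into the `XLA_FLAGS` environment variable. The helper `merge_xla_flags` is hypothetical, and the two flags shown are only examples of real XLA GPU options for latency hiding and Triton GEMM. Their values here are assumptions, not the settings the containers ship with.

```python
import os


def merge_xla_flags(new_flags, env=None):
    """Return an XLA_FLAGS string in which new_flags override same-named flags in env."""
    env = os.environ if env is None else env
    merged = {}
    # Later tokens win, so flags passed in new_flags replace existing ones.
    for token in env.get("XLA_FLAGS", "").split() + list(new_flags):
        name, _, value = token.partition("=")
        merged[name] = value
    return " ".join(f"{k}={v}" if v else k for k, v in merged.items())


# Illustrative values only (assumption): check the container's own environment
# for the flags it actually sets.
ILLUSTRATIVE_FLAGS = [
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_triton_gemm=false",
]

if __name__ == "__main__":
    # XLA reads XLA_FLAGS at startup, so set it before importing jax.
    os.environ["XLA_FLAGS"] = merge_xla_flags(ILLUSTRATIVE_FLAGS)
    print(os.environ["XLA_FLAGS"])
```

Merging rather than overwriting keeps any flags already present in the container's environment, which matters if the image pre-populates `XLA_FLAGS`.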

Quick Start & Requirements

  • Install/Run: Pull an image with docker pull ghcr.io/nvidia/jax:XXX. When starting a container, pass --shm-size=1g (docker run -it --shm-size=1g ...) to avoid potential bus errors.
  • Prerequisites: NVIDIA GPU, Docker, and potentially the NVIDIA Container Toolkit (or enroot/pyxis on Slurm clusters).
  • Resources: Requires sufficient GPU memory and disk space for Docker images.
  • Links: NVIDIA NGC Container, JAX on Public Clouds, Profiling.
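
The run command above can be assembled programmatically, for example when launching containers from a script. This is a minimal sketch: `docker_run_command` is a hypothetical helper, the `XXX` tag is the placeholder from the text and must be replaced with a real tag, and `--gpus=all` is an assumption about how GPUs are exposed (it requires the NVIDIA Container Toolkit).

```python
import shlex


def docker_run_command(tag, shm_size="1g", gpus="all", extra_args=()):
    """Build an argument list for running a ghcr.io/nvidia/jax image interactively."""
    image = f"ghcr.io/nvidia/jax:{tag}"
    return [
        "docker", "run", "-it",
        f"--gpus={gpus}",            # assumption: expose GPUs via the NVIDIA Container Toolkit
        f"--shm-size={shm_size}",    # larger shared memory to avoid potential bus errors
        *extra_args,
        image,
    ]


if __name__ == "__main__":
    # "XXX" is a placeholder tag, not a real one.
    print(shlex.join(docker_run_command("XXX")))
```

Returning a list rather than a string lets the result be passed directly to `subprocess.run` without shell quoting concerns.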

Highlighted Details

  • Supports popular LLM frameworks: MaxText, T5X, Levanter, AXLearn.
  • Includes optimized XLA flags for latency hiding and Triton GEMM.
  • Provides specific container tags for stable, dated builds.
  • Offers integration guides for major cloud providers (AWS, GCP, Azure).

Maintenance & Community

The project is maintained by NVIDIA. Links to community resources like Discord/Slack are not explicitly provided in the README.

Licensing & Compatibility

The README does not explicitly state a license. Compatibility for commercial use or closed-source linking is not specified.

Limitations & Caveats

The NCCL_NVLS_ENABLE environment variable is currently set to 0 (disabling NVLink SHARP), with plans to re-enable it in future releases. Users encountering issues with multi-arch images in enroot/pyxis may need to upgrade enroot to v3.4.0 or later.
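
Both caveats can be checked from a script before launching a job. The sketch below is illustrative only: the function names are hypothetical, and the enroot version string is assumed to be supplied by the caller (for example, copied from the tool's version output) rather than queried here.

```python
import os

MIN_ENROOT = (3, 4, 0)  # minimum enroot version named in the caveat above


def nvls_disabled(env=None):
    """Return True if NCCL_NVLS_ENABLE is explicitly set to 0 in the environment."""
    env = os.environ if env is None else env
    return env.get("NCCL_NVLS_ENABLE") == "0"


def parse_version(text):
    """Parse a version such as 'v3.4.1' or '3.4' into a 3-tuple of ints."""
    core = text.strip().lstrip("v").split("-")[0]
    parts = [int(p) for p in core.split(".")]
    parts += [0] * (3 - len(parts))  # pad so that '3.4' compares equal to '3.4.0'
    return tuple(parts[:3])


def enroot_needs_upgrade(version_text):
    """Return True if the given enroot version is older than MIN_ENROOT."""
    return parse_version(version_text) < MIN_ENROOT


if __name__ == "__main__":
    print("NVLS disabled:", nvls_disabled())
    print("Upgrade enroot:", enroot_needs_upgrade("3.3.0"))
```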

Health Check

  • Last commit: 1 day ago
  • Responsiveness: 1 day
  • Pull requests (30d): 49
  • Issues (30d): 1
  • Star history: 10 stars in the last 30 days
