JAX-Toolbox by NVIDIA

JAX-Toolbox provides CI, Docker images, and examples for JAX development on NVIDIA GPUs.

created 2 years ago
327 stars

Project Summary

NVIDIA/JAX-Toolbox provides pre-built Docker images and optimized examples for popular JAX libraries, targeting researchers and engineers working with large language models and other deep learning tasks on NVIDIA GPUs. It simplifies setup and enhances performance by bundling optimized JAX frameworks and providing curated XLA/NCCL flags.

How It Works

The toolbox offers a collection of Docker images, each tailored for specific JAX frameworks like MaxText, T5X, and Levanter, supporting various model architectures. These images are pre-configured with performance-enhancing XLA flags and NCCL settings, aiming to maximize GPU utilization and reduce communication overhead. NVIDIA maintains a public CI pipeline for these containers and conducts internal testing on high-end NVIDIA hardware.
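As a rough illustration of how such settings are applied, the Python sketch below merges XLA options into the `XLA_FLAGS` environment variable and sets an NCCL variable before JAX initializes. Both flag names are real XLA GPU options, but the values shown, the helper names (`build_xla_flags`, `configure_environment`), and the choice to set them from Python are assumptions for illustration. The images ship their own defaults, which may differ.

```python
import os
from typing import Dict, List

# Example values only (assumption): the set and values baked into each
# JAX-Toolbox image may differ; inspect the container's environment for
# what it actually sets.
EXAMPLE_XLA_FLAGS: Dict[str, str] = {
    "xla_gpu_enable_latency_hiding_scheduler": "true",
    "xla_gpu_enable_triton_gemm": "false",
}


def build_xla_flags(flags: Dict[str, str], existing: str = "") -> str:
    """Merge `flags` into an existing XLA_FLAGS string.

    Existing --key=value tokens are overridden by `flags`; any other tokens
    are passed through unchanged ahead of the merged key/value flags.
    """
    passthrough: List[str] = []
    merged: Dict[str, str] = {}
    for token in existing.split():
        if token.startswith("--") and "=" in token:
            key, value = token[2:].split("=", 1)
            merged[key] = value
        else:
            passthrough.append(token)
    merged.update(flags)
    return " ".join(passthrough + [f"--{k}={v}" for k, v in merged.items()])


def configure_environment() -> None:
    """Set XLA/NCCL variables; must run before the first `import jax`."""
    os.environ["XLA_FLAGS"] = build_xla_flags(
        EXAMPLE_XLA_FLAGS, os.environ.get("XLA_FLAGS", "")
    )
    # Mirrors the images' current default of disabling NVLink SHARP,
    # without overriding a value the user has already exported.
    os.environ.setdefault("NCCL_NVLS_ENABLE", "0")


if __name__ == "__main__":
    configure_environment()
    print("XLA_FLAGS =", os.environ["XLA_FLAGS"])
    print("NCCL_NVLS_ENABLE =", os.environ["NCCL_NVLS_ENABLE"])
```

Setting the variables in the process environment, rather than passing them to JAX directly, matters because XLA reads `XLA_FLAGS` only once, when the GPU backend is first initialized.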

Quick Start & Requirements

  • Install/Run: Pull an image with docker pull ghcr.io/nvidia/jax:XXX, then start it with docker run -it --shm-size=1g ...; the larger shared-memory size avoids potential bus errors.
  • Prerequisites: NVIDIA GPU and Docker with the NVIDIA Container Toolkit, or enroot/pyxis on Slurm clusters.
  • Resources: Requires sufficient GPU memory and disk space for Docker images.
  • Links: NVIDIA NGC Container, JAX on Public Clouds, Profiling.
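A minimal sketch of assembling that launch command in Python, assuming Docker with the NVIDIA Container Toolkit installed (so `--gpus all` is available). The tag `"TAG"` is a hypothetical placeholder for one of the published tags, and the helper name `docker_run_command` is illustrative.

```python
import shlex
from typing import List, Optional

IMAGE = "ghcr.io/nvidia/jax"


def docker_run_command(
    tag: str, shm_size: str = "1g", command: Optional[List[str]] = None
) -> List[str]:
    """Build an interactive `docker run` argument list for a JAX-Toolbox image."""
    cmd = [
        "docker", "run", "-it",
        "--gpus", "all",           # expose every GPU; needs the NVIDIA Container Toolkit
        f"--shm-size={shm_size}",  # enlarge /dev/shm to avoid bus errors
        f"{IMAGE}:{tag}",
    ]
    if command:
        cmd.extend(command)        # overrides the image's default CMD
    return cmd


if __name__ == "__main__":
    # "TAG" is a placeholder; substitute a real published tag.
    print(shlex.join(docker_run_command("TAG")))
```

Returning an argument list keeps the command safe to pass to `subprocess.run`, while `shlex.join` turns it into a copy-pasteable shell line.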

Highlighted Details

  • Supports popular LLM frameworks: MaxText, T5X, Levanter, AXLearn.
  • Includes optimized XLA flags for latency hiding and Triton GEMM.
  • Provides specific container tags for stable, dated builds.
  • Offers integration guides for major cloud providers (AWS, GCP, Azure).

Maintenance & Community

The project is maintained by NVIDIA. Links to community resources like Discord/Slack are not explicitly provided in the README.

Licensing & Compatibility

The README does not explicitly state a license. Compatibility for commercial use or closed-source linking is not specified.

Limitations & Caveats

The NCCL_NVLS_ENABLE environment variable is currently set to 0 (disabling NVLink SHARP), with plans to re-enable it in future releases. Users encountering issues with multi-arch images in enroot/pyxis may need to upgrade enroot to v3.4.0 or later.
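To check whether a cluster's enroot meets the v3.4.0 minimum, a small Python helper can parse the installed version. This is a sketch that assumes `enroot version` prints a dotted X.Y.Z version string; the function names are hypothetical.

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

MIN_ENROOT = (3, 4, 0)


def parse_version(text: str) -> Optional[Tuple[int, ...]]:
    """Extract the first X.Y.Z triple from text, e.g. '3.4.1' -> (3, 4, 1)."""
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        return None
    return tuple(int(part) for part in match.groups())


def enroot_meets_minimum(version_text: str) -> bool:
    """True if the parsed version is at least MIN_ENROOT (tuple comparison)."""
    version = parse_version(version_text)
    return version is not None and version >= MIN_ENROOT


def installed_enroot_version() -> Optional[str]:
    """Return the output of `enroot version` (assumed format), or None if absent."""
    if shutil.which("enroot") is None:
        return None
    result = subprocess.run(
        ["enroot", "version"], capture_output=True, text=True, check=False
    )
    return result.stdout.strip() or None


if __name__ == "__main__":
    text = installed_enroot_version()
    if text is None:
        print("enroot not found on PATH")
    elif enroot_meets_minimum(text):
        print(f"enroot {text} meets the 3.4.0 minimum")
    else:
        print(f"enroot {text} is older than 3.4.0; consider upgrading")
```

Comparing integer tuples rather than raw strings avoids the classic pitfall where "3.10.0" sorts before "3.4.0" lexically.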

Health Check

  • Last commit: 19 hours ago
  • Responsiveness: 1 day
  • Pull requests (30d): 63
  • Issues (30d): 0
  • Star history: 30 stars in the last 90 days