jax-ai-stack  by jax-ml

A curated JAX ecosystem for generative AI development

Created 1 year ago
262 stars

Top 97.2% on SourcePulse


Summary

The JAX AI Stack provides a single installation package for a curated collection of JAX-based libraries, mirroring the tools Google researchers use to develop advanced generative AI models such as Imagen and Gemini. It targets engineers and researchers who want a robust, modular ecosystem for high-performance numerical computing and machine learning, and it simplifies access to the open-source libraries behind Google's JAX-based development stack.

How It Works

This project builds on JAX, a core Python library for array computation and automatic differentiation. It bundles key ecosystem packages: Flax for building neural networks, Optax for gradient-based optimization, Orbax for checkpointing, ml_dtypes for ML-specific numeric types (such as bfloat16), Chex for utilities that help write reliable JAX code (assertions and testing helpers), and Grain for data loading. Keeping each concern in its own library lets the packages evolve independently, while version pinning keeps the bundle mutually compatible.
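To make the foundation layer concrete, here is a minimal sketch that uses only core JAX: `jax.value_and_grad` computes the loss and its gradients, and `jax.jit` compiles the training step. The linear model, the synthetic data, and names such as `train_step` and `LEARNING_RATE` are hypothetical illustrations; in the stack, Flax would normally define the model and Optax would supply the optimizer in place of the hand-written gradient-descent update.

```python
import jax
import jax.numpy as jnp

# Synthetic linear-regression data, y = 3x + 1 (illustrative only).
x = jnp.linspace(-1.0, 1.0, 32)
y = 3.0 * x + 1.0

LEARNING_RATE = 0.1  # hypothetical step size for plain gradient descent


def loss_fn(params, x, y):
    """Mean squared error of a linear model with params = (w, b)."""
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One gradient-descent step; Optax would normally replace the update rule."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Apply the update to every leaf of the params pytree (here a (w, b) tuple).
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(200):
    params, loss = train_step(params, x, y)

print("w, b =", float(params[0]), float(params[1]), "loss =", float(loss))
```

The reported `loss` is the value computed just before the final update, so it lags the returned parameters by one step.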

Quick Start & Requirements

  • Primary Install: pip install jax-ai-stack
  • Prerequisites: Python and pip. TensorFlow Datasets support is optional and is installed through the extra: pip install "jax-ai-stack[tfds]".
  • Hardware Support: Install with hardware-specific JAX support, e.g., pip install jax-ai-stack "jax[cuda]" for GPU/CUDA or pip install jax-ai-stack "jax[tpu]" for TPUs. Refer to the JAX installation guide for available options.
  • Documentation Links: The README mentions resources such as "Awesome JAX" for an up-to-date list of projects, "Getting started with JAX" for initial steps, and the "JAX installation" guide for hardware-specific setup.
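After installing, a quick way to confirm what ended up in the environment is to query package metadata and ask JAX which backend it selected. This is a minimal sketch using the standard-library `importlib.metadata` and `jax.default_backend()`; the distribution names in `PACKAGES` are an assumption (the usual PyPI names of the libraries listed above), not a list taken from jax-ai-stack's own pins, and the helper names are hypothetical.

```python
from importlib import metadata

# Assumed PyPI distribution names for the bundled libraries; adjust as needed.
PACKAGES = ["jax-ai-stack", "jax", "flax", "optax", "orbax-checkpoint",
            "ml_dtypes", "chex", "grain"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def jax_backend():
    """Return JAX's default backend (e.g. 'cpu', 'gpu', 'tpu'), or None without JAX."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:18} {version or 'not installed'}")
    print("JAX backend:", jax_backend())
```

If a GPU or TPU install succeeded, the reported backend should be `gpu` or `tpu` rather than `cpu`.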

Highlighted Details

  • Provides a single installation entry point to the JAX-based libraries Google uses for its own AI development.
  • Includes version-pinned components validated via integration tests for compatibility.
  • Supports development of large-scale generative models.

Maintenance & Community

The project is explicitly noted as a "work-in-progress" with plans for future documentation and tutorials. No specific community channels (e.g., Discord, Slack) or prominent maintainer information are detailed in the provided README.

Licensing & Compatibility

The license type and compatibility notes for commercial use or closed-source linking are not specified in the provided README content.

Limitations & Caveats

The JAX AI Stack is currently a work-in-progress: documentation and tutorials are incomplete and may change significantly. Users should expect ongoing development and potential changes.

Health Check
Last Commit

2 weeks ago

Responsiveness

1 day

Pull Requests (30d)
11
Issues (30d)
0
Star History
11 stars in the last 30 days

Explore Similar Projects

Starred by Andrej Karpathy (Founder of Eureka Labs; formerly at Tesla, OpenAI; author of CS 231n), Woosuk Kwon (coauthor of vLLM), and 15 more.

torchtitan by pytorch

0.6%
5k
PyTorch platform for generative AI model training research
Created 2 years ago
Updated 1 day ago