jax-llm-examples by jax-ml

High-performance LLM implementations in JAX

Created 1 year ago
251 stars

Top 99.9% on SourcePulse

Project Summary

Summary

This repository, jax-ml/jax-llm-examples, provides a curated collection of high-performance large language model (LLM) implementations written purely in JAX. It is aimed at engineers, researchers, and power users, offering minimal yet efficient reference implementations of state-of-the-art LLMs that are easy to understand, experiment with, and adapt within the JAX ecosystem.

How It Works

The project leverages JAX's just-in-time (XLA) compilation and automatic differentiation to run LLMs efficiently on hardware accelerators. By keeping each implementation "minimal", it prioritizes clarity and directness, letting users grasp the essential components of a model without excessive abstraction; this makes the code useful both for learning and as a solid, performant foundation for custom work.
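
To make this concrete, here is a minimal illustrative sketch (not code from this repository) of the pattern such implementations build on: a pure attention function compiled once by jax.jit via XLA and then reused as a fused accelerator program.

```python
# Illustrative sketch only, assuming nothing about this repository's code:
# jax.jit traces the pure function once and compiles it with XLA, so the
# whole attention computation runs as a single fused accelerator program.
import jax
import jax.numpy as jnp

@jax.jit
def attention(q, k, v):
    # Scaled dot-product attention for one head; inputs are (seq, d).
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(key, (128, 64)) for _ in range(3))
out = attention(q, k, v)  # first call compiles; later calls reuse the binary
print(out.shape)  # (128, 64)
```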

Quick Start & Requirements

Specific installation commands are not included in the README snippet, but it directs users to multi_host_README.md and a tpu_toolkit.sh script for multi-host cluster setup and distributed training. This suggests the examples are geared toward distributed computing environments and may require access to clusters and TPUs for optimal use.
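
As a hedged illustration (not taken from the repository's scripts, which may prescribe different steps), multi-host JAX programs typically begin by initializing the distributed runtime on every host; jax.distributed.initialize() is the standard entry point, and on Cloud TPU pods it can usually discover the coordinator automatically.

```python
# Sketch of a typical multi-host JAX entry point; the repository's
# multi_host_README.md and tpu_toolkit.sh may set things up differently.
import jax

# Run the same script on every host. On Cloud TPU pods the coordinator is
# usually discovered automatically; on other clusters, pass
# coordinator_address, num_processes, and process_id explicitly.
jax.distributed.initialize()

print(f"process {jax.process_index()} of {jax.process_count()}: "
      f"{jax.local_device_count()} local / {jax.device_count()} global devices")
```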

Highlighted Details

  • Features implementations for a diverse range of LLMs, including DeepSeek R1, Llama 4, Llama 3, Qwen 3, Kimi K2, OpenAI GPT OSS, NVIDIA Nemotron 3 Nano, and Gemma 4.
  • Emphasizes "minimal yet performant" code, focusing on essential LLM components (a sketch of the explicit-sharding style this implies in JAX follows this list).
  • Provides specific resources (multi_host_README.md, tpu_toolkit.sh) for advanced multi-host cluster setup and distributed training.
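
As a sketch of what "minimal yet performant" usually means in distributed JAX code (the mesh axis names "data" and "model" below are assumptions, not taken from this repository), arrays are placed on a device mesh with explicit sharding annotations, and jax.jit inserts the required collectives.

```python
# Hypothetical sharding sketch; axis names and shapes are illustrative.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all visible devices into a 2-D mesh (e.g. 8 devices -> 8 x 1).
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("data", "model"))

# Shard the weight matrix over the "model" axis, activations over "data".
w = jax.device_put(jnp.ones((1024, 4096)), NamedSharding(mesh, P(None, "model")))
x = jax.device_put(jnp.ones((32, 1024)), NamedSharding(mesh, P("data", None)))

y = jax.jit(lambda a, b: a @ b)(x, w)  # XLA adds any needed collectives
print(y.sharding)
```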

Maintenance & Community

No information regarding contributors, sponsorships, community channels (like Discord/Slack), or roadmaps is present in the provided README snippet.

Licensing & Compatibility

The README snippet does not specify the project's license type or any compatibility notes relevant to commercial use or integration with closed-source projects.

Limitations & Caveats

The project is explicitly described as "in progress," indicating that it is under active development and may not represent a stable, production-ready release. The emphasis on multi-host cluster setup and distributed training suggests that the examples are primarily designed for large-scale, distributed environments, and may require significant adaptation for single-machine or smaller-scale deployments.

Health Check

  • Last Commit: 2 weeks ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 3
  • Issues (30d): 1
  • Star History: 5 stars in the last 30 days

Explore Similar Projects

Starred by Jeff Hammerbacher (Cofounder of Cloudera), Edward Sun (Research Scientist at Meta Superintelligence Lab), and 1 more.

jaxformer by salesforce
0% · 299 stars
JAX library for LLM training on TPUs
Created 3 years ago · Updated 2 years ago

Starred by Matthew Johnson (Coauthor of JAX; Research Scientist at Google Brain), Roy Frostig (Coauthor of JAX; Research Scientist at Google DeepMind), and 3 more.

sglang-jax by sgl-project
1.1% · 269 stars
High-performance LLM inference engine for JAX/TPU serving
Created 9 months ago · Updated 5 hours ago

Starred by Pawel Garbacki (Cofounder of Fireworks AI) and Yineng Zhang (Inference Lead at SGLang; Research Scientist at Together AI).

aiconfigurator by ai-dynamo
2.6% · 275 stars
LLM serving configuration optimization
Created 9 months ago · Updated 3 days ago

Starred by Jiayi Pan (Author of SWE-Gym; MTS at xAI), Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), and 12 more.

EasyLM by young-geng
0.0% · 3k stars
LLM training/finetuning framework in JAX/Flax
Created 3 years ago · Updated 1 year ago