jax-triton by jax-ml

JAX integrations for OpenAI Triton

Created 3 years ago
460 stars

Top 65.2% on SourcePulse

Project Summary

This project integrates JAX with OpenAI Triton, so users can write custom, high-performance GPU kernels that operate on JAX arrays. It targets researchers and engineers who need to optimize computations that standard JAX operations do not express efficiently, where a hand-written Triton kernel can deliver significant speedups.

How It Works

The core of jax-triton is the jax_triton.triton_call function, which launches a Triton kernel on JAX arrays and also works inside jax.jit-compiled functions. Users define kernels with Triton's Pythonic API and apply them directly to JAX arrays, gaining Triton's low-level control over GPU hardware without leaving the JAX workflow.

Quick Start & Requirements

Highlighted Details

  • Enables writing custom GPU kernels using Triton's domain-specific language.
  • Works inside jax.jit-compiled functions, so custom kernels compose with the rest of a compiled JAX program.
  • Provides examples for advanced use cases like fused attention.

Maintenance & Community

  • This is not an officially supported Google product.
  • For development, clone the repository and install it in editable mode.
  • Tests can be run with pytest.

Licensing & Compatibility

  • The repository does not explicitly state a license in the README.

Limitations & Caveats

The project requires building Triton from source, which can complicate setup. Compatibility with specific CUDA versions or JAX builds may need to be verified manually.

Health Check
Last Commit: 1 month ago
Responsiveness: 1 day
Pull Requests (30d): 5
Issues (30d): 0
Star History: 13 stars in the last 30 days

Explore Similar Projects

Starred by Luis Capelo (Cofounder of Lightning AI), Sasha Rush (Research Scientist at Cursor; Professor at Cornell Tech), and 3 more.

jax-js by ekzhang

0.3%
811
JAX-style ML framework for the web
Created 1 year ago
Updated 15 hours ago
Starred by Matthew Johnson (Coauthor of JAX; Research Scientist at Google Brain), Roy Frostig (Coauthor of JAX; Research Scientist at Google DeepMind), and 3 more.

sglang-jax by sgl-project

0.7%
275
High-performance LLM inference engine for JAX/TPU serving
Created 10 months ago
Updated 9 hours ago
Starred by Elvis Saravia (Founder of DAIR.AI), Roy Frostig (Coauthor of JAX; Research Scientist at Google DeepMind), and 8 more.

numpyro by pyro-ppl

0.2%
3k
Probabilistic programming library using JAX for GPU/TPU/CPU
Created 7 years ago
Updated 2 days ago