jax-triton by jax-ml

JAX integrations for OpenAI Triton

Created 3 years ago
422 stars

Top 69.8% on SourcePulse

Project Summary

This project integrates JAX with OpenAI Triton, letting users write custom, high-performance GPU kernels that operate on JAX arrays. It targets researchers and engineers who need to optimize computations beyond what standard JAX operations offer, where hand-written GPU kernels can yield meaningful speedups.

How It Works

The core of jax-triton is the jax_triton.triton_call function, which launches a Triton kernel on JAX arrays and can be used inside jax.jit-compiled functions. Users write kernels with Triton's Python-based API and call them directly from JAX code, gaining Triton's low-level control over GPU execution where it helps performance.

Quick Start & Requirements

Highlighted Details

  • Enables writing custom GPU kernels using Triton's domain-specific language.
  • Kernels can be called inside jax.jit-compiled functions.
  • Provides examples for advanced use cases like fused attention.
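A fused attention kernel typically combines the score matrix multiplication, the softmax, and the value matrix multiplication into a single GPU kernel. When writing such a kernel, an unfused plain-JAX version is a convenient correctness baseline. The sketch below is that kind of reference, not the repository's fused-attention example; `reference_attention` is a hypothetical name and the (batch, heads, seq_len, head_dim) layout is an assumption.

```python
import jax
import jax.numpy as jnp


def reference_attention(q, k, v):
    """Unfused scaled dot-product attention in plain JAX.

    Assumed layout for q, k, v: (batch, heads, seq_len, head_dim).
    """
    scale = q.shape[-1] ** -0.5
    # Attention scores: (batch, heads, q_len, k_len).
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    # Normalize over the key axis so each query's weights sum to 1.
    weights = jax.nn.softmax(logits, axis=-1)
    # Weighted sum of values: (batch, heads, q_len, head_dim).
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)


kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 4, 16, 32)
q = jax.random.normal(kq, shape)
k = jax.random.normal(kk, shape)
v = jax.random.normal(kv, shape)

out = jax.jit(reference_attention)(q, k, v)
print(out.shape)  # (2, 4, 16, 32)
```

A fused kernel may reorder floating-point operations, so its output would usually be compared with a reference like this one using a tolerance rather than exact equality.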

Maintenance & Community

  • This is not an officially supported Google product.
  • For development, clone the repository and install it in editable mode.
  • Tests can be run with pytest.

Licensing & Compatibility

  • The repository does not explicitly state a license in the README.

Limitations & Caveats

The project requires building Triton from source, which can add complexity to the setup process. Compatibility with specific CUDA versions or JAX builds may require manual verification.
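Because compatibility with specific CUDA versions and JAX builds may need manual verification, it can help to record the exact environment before debugging a kernel. A minimal sketch; `describe_environment` is a hypothetical helper, not part of jax-triton, and the distribution names it looks up ("jaxlib", "triton", "jax-triton") are assumptions about how the packages were installed.

```python
import importlib.metadata

import jax


def describe_environment():
    """Hypothetical helper: collect version and backend info for jax-triton."""
    info = {
        "jax": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }
    # Record installed package versions; None means the package was not found.
    for dist in ("jaxlib", "triton", "jax-triton"):
        try:
            info[dist] = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            info[dist] = None
    return info


if __name__ == "__main__":
    for name, value in describe_environment().items():
        print(f"{name}: {value}")
```

A backend of "cpu" or a missing triton entry is an early sign that Triton kernels will not run in the current environment.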

Health Check
  • Last Commit: 2 weeks ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 3
  • Issues (30d): 0
  • Star History: 9 stars in the last 30 days

Explore Similar Projects

Starred by Roy Frostig (Coauthor of JAX; Research Scientist at Google DeepMind), Jiayi Pan (Author of SWE-Gym; MTS at xAI), and 2 more.

HighPerfLLMs2024 by rwitten

0.6%
534
Jax course for high-performance LLM construction
Created 1 year ago
Updated 1 year ago
Starred by Elvis Saravia (Founder of DAIR.AI), Roy Frostig (Coauthor of JAX; Research Scientist at Google DeepMind), and 8 more.

numpyro by pyro-ppl

0.1%
3k
Probabilistic programming library using JAX for GPU/TPU/CPU
Created 6 years ago
Updated 1 week ago