jax-js by ekzhang

JAX-style ML framework for the web

Created 1 year ago
723 stars

Top 47.6% on SourcePulse

Summary

jax-js brings JAX-style, high-performance numerical computing and machine learning directly to the web browser. It targets developers and researchers who need to run ML models or numerical simulations client-side. It offers a portable option that uses WebGPU and WebAssembly for speed behind a familiar NumPy/JAX API.

How It Works

The library compiles array operations into an intermediate representation, then generates optimized kernels from it for WebAssembly (CPU) and WebGPU (GPU). It is written from scratch with zero external dependencies and prioritizes API compatibility with NumPy and JAX. Because everything executes client-side, it runs anywhere a modern browser is available.
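The idea of compiling a chain of array operations into a single generated kernel can be shown with a toy example. The sketch below is hypothetical: the Expr IR, emit, and compileKernel are illustrative names, not jax-js internals. For simplicity it emits JavaScript source rather than WGSL or Wasm. It fuses an expression tree into one loop, so the data is traversed once instead of once per operation.

```typescript
// Hypothetical IR for elementwise expressions (NOT jax-js's internal representation).
type Expr =
  | { op: "input"; name: string }
  | { op: "const"; value: number }
  | { op: "add" | "mul"; lhs: Expr; rhs: Expr };

// Emit a scalar expression that computes element `i` of the result.
function emit(e: Expr): string {
  switch (e.op) {
    case "input":
      return `${e.name}[i]`;
    case "const":
      return String(e.value);
    case "add":
      return `(${emit(e.lhs)} + ${emit(e.rhs)})`;
    case "mul":
      return `(${emit(e.lhs)} * ${emit(e.rhs)})`;
  }
}

// "Synthesize" one fused kernel: a single loop evaluates the whole tree per element.
// Input names must be valid JavaScript identifiers; all inputs share one length.
function compileKernel(
  e: Expr,
  inputs: string[],
): (...arrays: Float32Array[]) => Float32Array {
  const body = `
    const out = new Float32Array(${inputs[0]}.length);
    for (let i = 0; i < out.length; i++) out[i] = ${emit(e)};
    return out;`;
  return new Function(...inputs, body) as (...arrays: Float32Array[]) => Float32Array;
}

// y = x * 2 + 1, compiled into one loop instead of two separate passes.
const expr: Expr = {
  op: "add",
  lhs: { op: "mul", lhs: { op: "input", name: "x" }, rhs: { op: "const", value: 2 } },
  rhs: { op: "const", value: 1 },
};
const fused = compileKernel(expr, ["x"]);
const result = fused(new Float32Array([1, 2, 3])); // Float32Array [3, 5, 7]
```

A real backend would also handle shapes, broadcasting, and reductions. The point here is only that an IR makes it possible to generate one specialized kernel for a whole expression.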

Quick Start & Requirements

Install from npm with npm i @jax-js/jax, then import numpy as np from the package. A browser with WebGPU support is recommended for best performance; without it, jax-js falls back to Wasm. Official resources include the Website, API Reference, Compatibility Table, and Discord.
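The "WebGPU if available, otherwise Wasm" behavior can be illustrated with standard browser feature detection. This is a hypothetical sketch: pickBackend and GpuLike are illustrative names, not part of jax-js, whose own selection logic may differ. Only navigator.gpu and requestAdapter() come from the WebGPU API itself.

```typescript
// Hypothetical backend choice; not jax-js's actual device-selection code.
type Backend = "webgpu" | "wasm";

// Minimal structural type for the piece of the WebGPU API used here.
interface GpuLike {
  requestAdapter(): Promise<object | null>;
}

async function pickBackend(nav: { gpu?: GpuLike } | undefined): Promise<Backend> {
  // navigator.gpu is only defined in browsers that expose WebGPU.
  if (nav?.gpu) {
    try {
      // requestAdapter() resolves to null when no suitable GPU adapter exists.
      const adapter = await nav.gpu.requestAdapter();
      if (adapter !== null) return "webgpu";
    } catch {
      // Treat adapter errors like "no GPU" and fall through to Wasm.
    }
  }
  return "wasm";
}

// In a browser this checks the real navigator; elsewhere it falls back to "wasm".
const nav = (globalThis as unknown as { navigator?: { gpu?: GpuLike } }).navigator;
pickBackend(nav).then((backend) => console.log(`using ${backend}`));
```

Checking for an adapter, not just for navigator.gpu, matters because a browser can expose the API while still having no usable GPU.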

Highlighted Details

  • Performance: Achieves over 7000 GFLOP/s for matrix multiplication on high-end hardware via its WebGPU backend.
  • JAX Features: Implements core JAX transformations including grad (autodiff), vmap (vectorization), and jit (kernel fusion) for performance optimization.
  • Backends: Supports WebGPU for high-performance GPU acceleration, WebGL2 as a fallback, and WebAssembly for CPU execution.
  • API & Data Types: Offers close API compatibility with NumPy/JAX, supporting Float32, Float64, and Float16 (partial).
  • Bundle Size: 80 KB (gzipped).
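To put the matrix-multiplication figure in perspective, throughput is normally computed with the standard 2·m·n·k FLOP count for an (m×k)·(k×n) product. The summary does not say which matrix sizes or hardware produced the 7000 GFLOP/s number. The 4096 size below is an assumed, purely illustrative example.

```typescript
// Standard FLOP count for C = A · B with A: m×k and B: k×n.
// Each of the m·n outputs needs k multiplies and k adds, so 2·m·n·k FLOPs.
function matmulFlops(m: number, n: number, k: number): number {
  return 2 * m * n * k;
}

// Throughput in GFLOP/s, given a measured wall-clock time in seconds.
function gflopsPerSecond(m: number, n: number, k: number, seconds: number): number {
  return matmulFlops(m, n, k) / seconds / 1e9;
}

// Time one matmul would take at a given sustained throughput (in GFLOP/s).
function secondsAt(m: number, n: number, k: number, gflops: number): number {
  return matmulFlops(m, n, k) / (gflops * 1e9);
}

// Assumed example size: a 4096 × 4096 × 4096 matmul.
const n = 4096;
const ms = secondsAt(n, n, n, 7000) * 1000;
console.log(`${(matmulFlops(n, n, n) / 1e9).toFixed(1)} GFLOP, ${ms.toFixed(1)} ms at 7000 GFLOP/s`);
// prints "137.4 GFLOP, 19.6 ms at 7000 GFLOP/s"
```

The same arithmetic works in reverse: time a matmul of known size and divide 2·m·n·k by the elapsed seconds to get the achieved GFLOP/s.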

Maintenance & Community

The project maintains a Discord server for community interaction. Development uses pnpm for package management and Vitest for testing, and the README includes notes on configuring Playwright for headless WebGPU tests. Planned work includes broader coverage of JAX functions, better runtime performance, and a fast transformer inference engine.

Licensing & Compatibility

The project's license is not explicitly stated in the README, which is a critical omission for assessing commercial use or derivative works.

Limitations & Caveats

The WebAssembly backend is currently single-threaded and blocking. While WebGL2 is supported, it's offered on a best-effort basis and is significantly slower than WebGPU. Several advanced numerical operations and data types (e.g., SVD, BFloat16) are not yet implemented or have partial support. Memory management relies on explicit reference counting, requiring careful handling by the developer to prevent leaks.
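The reference-counting caveat is easiest to see in a small model. This sketch is hypothetical: RcBuffer, retain, and release are illustrative names and are not jax-js's memory-management API. It only shows why every extra reference must be matched by a release, and how a missing release becomes a leak.

```typescript
// Hypothetical reference-counted buffer (NOT jax-js's API); illustrates the pattern only.
class RcBuffer {
  private count = 1; // the creator holds the first reference
  private data: Float32Array | null;
  private readonly onFree: (bytes: number) => void;

  constructor(length: number, onFree: (bytes: number) => void) {
    this.data = new Float32Array(length);
    this.onFree = onFree;
  }

  // Take an additional reference, e.g. before handing the buffer to a second owner.
  retain(): RcBuffer {
    if (this.data === null) throw new Error("use after free");
    this.count++;
    return this;
  }

  // Drop one reference; the memory is freed only when the last reference goes away.
  release(): void {
    if (this.data === null) throw new Error("double free");
    this.count--;
    if (this.count === 0) {
      const bytes = this.data.byteLength;
      this.data = null; // a GPU backend would destroy its device buffer here
      this.onFree(bytes);
    }
  }

  get refCount(): number {
    return this.count;
  }

  get freed(): boolean {
    return this.data === null;
  }
}

let freedBytes = 0;
const buf = new RcBuffer(256, (bytes) => {
  freedBytes += bytes;
});
const alias = buf.retain(); // two owners now share the buffer (refCount 2)
buf.release(); // refCount 1: still alive
alias.release(); // refCount 0: freed, freedBytes === 1024 (256 floats × 4 bytes)
// Skipping either release() would leave the buffer allocated: a leak.
```

The JavaScript garbage collector cannot reclaim GPU memory on its own schedule. For that reason, libraries that manage device buffers often make ownership explicit like this.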

Health Check
Last Commit: 6 days ago
Responsiveness: Inactive
Pull Requests (30d): 9
Issues (30d): 11
Star History: 48 stars in the last 30 days
