whisper-jax  by sanchit-gandhi

JAX implementation of OpenAI's Whisper model for fast TPU inference

created 2 years ago
4,616 stars

Top 10.9% on sourcepulse

Project Summary

This repository provides an optimized JAX implementation of OpenAI's Whisper model, targeting researchers and developers needing high-performance speech-to-text transcription. It offers significant speedups (up to 70x) over PyTorch implementations, particularly on TPUs, making it suitable for large-scale audio processing and real-time applications.

How It Works

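Whisper JAX uses JAX's pmap for data parallelism across multiple accelerators (GPUs or TPUs). The core FlaxWhisperPipline class (the spelling used by the library) handles pre- and post-processing and JIT-compiles the forward pass, so the first call is slow and later calls reuse the compiled function. It supports half precision (float16 or bfloat16) for further speed gains, plus an optional batching mechanism that splits long audio into chunks and transcribes them in parallel. The project reports roughly a 10x speedup from batching compared with sequential transcription, with minimal impact on word error rate (WER).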
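The chunk-and-batch idea described above can be pictured with a small sketch. This is not the library's code: the helper names chunk_bounds and make_batches and the 5-second overlap are assumptions chosen for illustration. Only the 16 kHz sample rate and the 30-second window are standard properties of Whisper's input.

```python
from typing import List, Sequence, Tuple

SAMPLE_RATE = 16_000  # Whisper models expect 16 kHz audio


def chunk_bounds(n_samples: int, chunk_len: int, overlap: int) -> List[Tuple[int, int]]:
    """Return (start, end) sample indices of overlapping windows covering the audio.

    Hypothetical helper: consecutive windows share `overlap` samples so that
    words cut at a boundary appear whole in at least one window.
    """
    if not 0 <= overlap < chunk_len:
        raise ValueError("overlap must be in [0, chunk_len)")
    step = chunk_len - overlap
    bounds = []
    start = 0
    while start < n_samples:
        end = min(start + chunk_len, n_samples)
        bounds.append((start, end))
        if end == n_samples:  # the last window reached the end of the audio
            break
        start += step
    return bounds


def make_batches(items: Sequence, batch_size: int) -> List[Sequence]:
    """Group items into consecutive batches of at most `batch_size` (hypothetical helper)."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


# 100 seconds of audio, 30-second windows, 5 seconds of overlap (illustrative value).
bounds = chunk_bounds(
    n_samples=100 * SAMPLE_RATE,
    chunk_len=30 * SAMPLE_RATE,
    overlap=5 * SAMPLE_RATE,
)
batches = make_batches(bounds, batch_size=2)
print(len(bounds), len(batches))  # 4 2
```

Each batch can then be sent through the model in a single forward pass, which is where the parallel speedup comes from. A real pipeline also has to merge the overlapping transcripts back together, a step this sketch leaves out.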

Quick Start & Requirements

  • Install: pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
  • Prerequisites: JAX and Python. The project was tested with JAX 0.4.5 and Python 3.9.
  • Resources: Compatible with CPU, GPU, and TPU. Kaggle notebook demonstrates 30-minute transcription in ~30 seconds on Cloud TPU.
  • Demo: Hugging Face Hub
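Once the package is installed, a first transcription might look like the sketch below. It assumes the FlaxWhisperPipline constructor and call arguments as shown in the project's README (dtype, batch_size, task, return_timestamps). The wrapper functions load_pipeline and transcribe, the checkpoint name default, and the file name audio.mp3 are placeholders introduced here for illustration.

```python
def load_pipeline(model_id: str = "openai/whisper-large-v2", batch_size: int = 16):
    """Build the pipeline once; reuse it so the JIT-compiled function stays cached."""
    # Imported lazily so JAX only initializes when a pipeline is actually requested.
    import jax.numpy as jnp
    from whisper_jax import FlaxWhisperPipline

    # bfloat16 is the half-precision type usually suited to TPUs (float16 on most GPUs).
    return FlaxWhisperPipline(model_id, dtype=jnp.bfloat16, batch_size=batch_size)


def transcribe(pipeline, audio_path: str):
    """Transcribe one file and request timestamps (hypothetical wrapper)."""
    return pipeline(audio_path, task="transcribe", return_timestamps=True)


if __name__ == "__main__":
    pipeline = load_pipeline()
    outputs = transcribe(pipeline, "audio.mp3")  # placeholder path; first call compiles
    print(outputs["text"])
```

Passing task="translate" instead of task="transcribe" requests speech translation. Building the pipeline once and reusing it matters, because each new pipeline object has to repeat the slow compilation step.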

Highlighted Details

  • Up to 70x faster than PyTorch implementations on TPU.
  • Supports batching for 10x speedup.
  • Half-precision computation (float16, bfloat16) for increased performance.
  • Timestamp prediction and speech translation capabilities.
  • Integrates with T5x for advanced model/data parallelism.

Maintenance & Community

  • Based on 🤗 Hugging Face Transformers.
  • Acknowledgements mention contributors from Hugging Face and Gradio.
  • Code is available for creating standalone inference endpoints.

Licensing & Compatibility

  • The README does not explicitly state a license. The project builds on Hugging Face Transformers and JAX, both of which are released under the Apache 2.0 license, so commercial use is likely possible. Confirm the repository's own license and those of its dependencies before relying on this.

Limitations & Caveats

  • Requires installation of JAX, which can have specific hardware/driver dependencies.
  • Advanced partitioning requires familiarity with T5x conventions.

Health Check

  • Last commit: 1 year ago
  • Responsiveness: Inactive
  • Pull requests (30d): 0
  • Issues (30d): 1

Star History

  • 43 stars in the last 90 days
