JAX implementation of OpenAI's Whisper model for fast TPU inference
This repository provides an optimized JAX implementation of OpenAI's Whisper model, targeting researchers and developers needing high-performance speech-to-text transcription. It offers significant speedups (up to 70x) over PyTorch implementations, particularly on TPUs, making it suitable for large-scale audio processing and real-time applications.
How It Works
Whisper JAX leverages JAX's `pmap` for efficient data parallelism across multiple accelerators (GPUs/TPUs). The core `FlaxWhisperPipline` class handles pre/post-processing and JIT compilation for rapid inference. It supports half-precision (`float16` or `bfloat16`) for further speed gains, and an optional batching mechanism that chunks long audio for parallel transcription, achieving up to a 10x speedup with minimal impact on word error rate (WER).
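The chunked-batching idea can be sketched in plain Python. This is an illustrative model of the mechanism, not the library's internal code: long audio is split into fixed-length chunks with a small overlap (stride) so chunk boundaries don't cut words, and the resulting chunks become parallel work items. The function name and parameters here are hypothetical.

```python
def chunk_audio(samples, chunk_len, stride):
    """Split a 1-D sequence of audio samples into overlapping chunks.

    Each chunk is at most chunk_len samples; consecutive chunks overlap
    by `stride` samples so speech at a boundary appears in both chunks.
    """
    step = chunk_len - stride  # how far the window advances each time
    chunks = []
    for start in range(0, len(samples), step):
        chunks.append(samples[start:start + chunk_len])
        if start + chunk_len >= len(samples):
            break  # the final chunk already covers the end of the audio
    return chunks

# Stand-in for 100 audio samples; real chunks would be ~30 s of audio.
samples = list(range(100))
chunks = chunk_audio(samples, chunk_len=30, stride=5)
print(len(chunks))  # → 4 parallel work items
```

Each chunk can then be transcribed independently, which is what makes the batched path data-parallel across devices.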
Quick Start & Requirements
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
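After installing, basic usage follows the pipeline pattern described above. A minimal sketch, assuming a Hugging Face Whisper checkpoint and a local audio file (`audio.mp3` is a placeholder); it requires a JAX install matched to your accelerator and will download model weights on first run:

```python
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

# Instantiate the pipeline in half precision with batched chunking enabled.
pipeline = FlaxWhisperPipline(
    "openai/whisper-large-v2",  # any Whisper checkpoint id
    dtype=jnp.bfloat16,          # or jnp.float16 on GPUs
    batch_size=16,               # chunks transcribed per parallel batch
)

# First call triggers JIT compilation; subsequent calls are fast.
outputs = pipeline("audio.mp3")
print(outputs["text"])
```

The first invocation is slow because JAX traces and compiles the model; cached compilations make later calls much faster.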
Highlighted Details
Half-precision support (`float16`, `bfloat16`) for increased performance.
Maintenance & Community
Licensing & Compatibility
Limitations & Caveats
The repository was last updated about a year ago and is currently inactive.