PyTorch/MLX library for Mamba sequence modeling
This repository provides a simple and efficient implementation of the Mamba architecture, a state-space model for sequence modeling, in both PyTorch and MLX. It aims to be educational, offering a readable codebase with good training performance, and includes extensions such as Jamba and Vision Mamba.
How It Works
The core of the implementation relies on a parallel scan (Blelloch scan) algorithm for the selective scan operation, which significantly speeds up computation over sequential approaches by parallelizing across the time dimension. This approach is key to Mamba's linear time complexity in sequence length. The PyTorch version implements the scan as a custom PScan autograd function, while the MLX version targets Apple Silicon hardware.
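The parallel-scan idea can be sketched in plain Python. The SSM recurrence h_t = a_t * h_{t-1} + b_t is associative under the combine rule (a1, b1) ∘ (a2, b2) = (a1*a2, a2*b1 + b2), so all h_t can be computed with a prefix scan in O(log T) parallel rounds instead of a sequential loop. The function names below are illustrative, not the library's actual PScan API, which operates on batched tensors:

```python
def combine(x, y):
    # Composing h -> a1*h + b1 then h -> a2*h + b2 gives
    # h -> (a1*a2)*h + (a2*b1 + b2): the operator is associative.
    a1, b1 = x
    a2, b2 = y
    return (a1 * a2, a2 * b1 + b2)

def sequential_scan(a, b, h0=0.0):
    """Reference: plain left-to-right recurrence h_t = a_t*h_{t-1} + b_t."""
    hs, h = [], h0
    for at, bt in zip(a, b):
        h = at * h + bt
        hs.append(h)
    return hs

def parallel_scan(a, b):
    """Inclusive prefix scan (Hillis-Steele form for clarity).
    Each round's list comprehension is independent per position, so on a
    GPU it runs in parallel across the time dimension; only O(log T)
    rounds are needed."""
    elems = list(zip(a, b))
    offset = 1
    while offset < len(elems):
        elems = [
            combine(elems[i - offset], elems[i]) if i >= offset else elems[i]
            for i in range(len(elems))
        ]
        offset *= 2
    # With h_0 = 0, the accumulated b-component of each prefix equals h_t.
    return [bt for _, bt in elems]
```

For example, with `a = [0.9, 0.5, 0.8]` and `b = [1.0, 2.0, 3.0]`, both functions return `[1.0, 2.5, 5.0]`.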
Quick Start & Requirements
pip install mambapy
Highlighted Details
Maintenance & Community
Licensing & Compatibility
Limitations & Caveats
Memory usage grows as d_state increases. torch.compile is not currently compatible due to issues with the custom PScan autograd function.