mamba.py by alxndrTL

PyTorch/MLX library for Mamba sequence modeling

created 1 year ago
1,296 stars

Top 31.4% on sourcepulse

Project Summary

This repository provides a simple and efficient implementation of the Mamba architecture, a selective state-space model for sequence modeling, in both PyTorch and MLX. It aims to be educational, offering a readable codebase that still achieves good training performance, and includes extensions such as Jamba and Vision Mamba.

How It Works

The core of the implementation is a parallel scan (Blelloch scan) over the selective state-space recurrence: instead of iterating over timesteps sequentially, the recurrence is evaluated in parallel across the time dimension, which keeps the total work linear in sequence length while greatly reducing wall-clock time on parallel hardware. Unlike the official Mamba repository, the PyTorch version does not rely on custom CUDA kernels; it implements the scan directly in PyTorch, which keeps the code readable at some cost in speed (see the limitations below). The MLX version targets Apple Silicon hardware.
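
The snippet below is a minimal sketch of this idea, not the library's pscan: it assumes the discretized recurrence h_t = a_t * h_{t-1} + b_t that the selective scan reduces to, and it uses a simpler Hillis-Steele-style doubling scan rather than the work-efficient Blelloch scan, but it exploits the same associative composition of affine maps that makes the parallelization possible.

```python
import torch

def sequential_selective_scan(a, b):
    """Reference recurrence h_t = a_t * h_{t-1} + b_t, iterated over time (dim 1)."""
    h = torch.zeros_like(b[:, 0])
    out = []
    for t in range(b.shape[1]):
        h = a[:, t] * h + b[:, t]
        out.append(h)
    return torch.stack(out, dim=1)

def doubling_selective_scan(a, b):
    """Inclusive scan over time in O(log L) doubling rounds (Hillis-Steele style).

    Position t carries the affine map h -> a_t * h + b_t; composing the map at t
    with the map at t - step uses the associative rule
        (a_g, b_g) followed by (a_f, b_f)  ->  (a_f * a_g, a_f * b_g + b_f),
    so after the last round, b_t equals h_t for an initial state of zero.
    """
    a, b = a.clone(), b.clone()
    L, step = a.shape[1], 1
    while step < L:
        a_prev = torch.ones_like(a)    # identity map for positions t < step
        b_prev = torch.zeros_like(b)
        a_prev[:, step:] = a[:, :-step]
        b_prev[:, step:] = b[:, :-step]
        a, b = a * a_prev, a * b_prev + b
        step *= 2
    return b

# The parallel version matches the sequential reference on random inputs.
a = torch.rand(2, 64, 16) * 0.9   # decay factors, e.g. exp(delta * A)
b = torch.randn(2, 64, 16)        # input contributions, e.g. delta * B * x
assert torch.allclose(sequential_selective_scan(a, b), doubling_selective_scan(a, b), atol=1e-4)
```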

Quick Start & Requirements

  • Install via pip: pip install mambapy
  • The PyTorch version benefits from a CUDA GPU for best performance; the MLX version targets Apple Silicon Macs.
  • Documentation and examples are available in the repository; a minimal usage sketch follows this list.
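
A minimal usage sketch of the PyTorch backend, assuming the Mamba and MambaConfig classes exposed by mambapy.mamba as shown in the repository README (exact configuration fields may differ):

```python
import torch
from mambapy.mamba import Mamba, MambaConfig  # import path as shown in the README

# Configure a small model: d_model is the channel width, n_layers the number of Mamba blocks.
config = MambaConfig(d_model=16, n_layers=2)
model = Mamba(config)

# A Mamba stack maps (batch, length, d_model) -> (batch, length, d_model).
x = torch.randn(2, 64, 16)
y = model(x)
assert y.shape == x.shape
```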

Highlighted Details

  • Implements Mamba, Mamba-2, and Jamba (Mamba + attention) architectures.
  • Includes Vision Mamba (ViM) and muP (maximal update parametrization) for hyperparameter transfer.
  • Offers both PyTorch and MLX backends for broader hardware compatibility.
  • Supports loading pretrained models from Hugging Face (see the loading sketch after this list).
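
As an illustration of the last point, the sketch below assumes a from_pretrained helper in mambapy.mamba_lm that downloads a state-spaces/* checkpoint from the Hugging Face Hub; the module path, function name, and checkpoint name are assumptions to verify against the repository README.

```python
import torch
from mambapy.mamba_lm import from_pretrained  # assumed module path; verify against the README

# Load pretrained Mamba language-model weights from the Hugging Face Hub (assumed checkpoint name).
model = from_pretrained("state-spaces/mamba-130m")
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# Run a forward pass on a short sequence of token ids and inspect the causal-LM logits.
input_ids = torch.randint(0, 1000, (1, 32), device=next(model.parameters()).device)
with torch.no_grad():
    logits = model(input_ids)
print(logits.shape)  # expected (1, 32, vocab_size)
```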

Maintenance & Community

  • The project has been integrated into the Hugging Face Transformers library.
  • Past activity includes regular updates and merged pull requests, though the health metrics below show little activity in recent months.
  • Citation information is provided.

Licensing & Compatibility

  • The repository does not explicitly state a license in the README. This requires further investigation for commercial use or closed-source linking.

Limitations & Caveats

  • The PyTorch implementation does not use recomputation in the backward pass, potentially leading to higher memory requirements during training compared to the official Mamba implementation.
  • Performance comparisons indicate mamba.py is approximately 2x slower than the official CUDA implementation, particularly as d_state increases.
  • torch.compile is not currently compatible due to issues with the custom PScan autograd function.
Health Check
Last commit: 8 months ago
Responsiveness: 1 day
Pull Requests (30d): 0
Issues (30d): 0
Star History
88 stars in the last 90 days

Explore Similar Projects

Starred by George Hotz (author of tinygrad; founder of the tiny corp and comma.ai), Shawn Wang (editor of Latent Space), and 9 more.

mamba by state-spaces
Mamba SSM architecture for sequence modeling
created 1 year ago, updated 2 weeks ago
16k stars
Top 0.4% on sourcepulse