mamba.py by alxndrTL

PyTorch/MLX library for Mamba sequence modeling

Created 1 year ago
1,317 stars

Top 30.4% on SourcePulse

Project Summary

This repository provides a simple, readable implementation of the Mamba architecture, a state-space model for efficient sequence modeling, in both PyTorch and MLX. It aims to be educational while still achieving good training performance, and includes extensions such as Jamba and Vision Mamba.

How It Works

The core of the implementation is a parallel scan (Blelloch scan) over the selective-scan recurrence, which parallelizes the computation across the time dimension and is far faster than a naive sequential loop while preserving the model's linear scaling in sequence length. Neither backend relies on custom CUDA kernels: the PyTorch version is written in pure PyTorch (a CUDA GPU improves throughput but is not required), while the MLX version targets Apple Silicon hardware.
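To make the idea concrete, here is a minimal sketch of the recurrence the scan parallelizes, h_t = a_t * h_{t-1} + b_t, together with a log-step doubling scan over the associative operator (a1, b1) ∘ (a2, b2) = (a1*a2, a2*b1 + b2). This is a simplified illustration only: the repository's pscan implements a full Blelloch up-/down-sweep, and the names below are not the library's.

```python
import torch


def scan_sequential(a, b):
    # Reference recurrence: h_t = a_t * h_{t-1} + b_t, with h_{-1} = 0 (so h_0 = b_0).
    h = torch.zeros_like(b)
    h[:, 0] = b[:, 0]
    for t in range(1, b.size(1)):
        h[:, t] = a[:, t] * h[:, t - 1] + b[:, t]
    return h


def scan_parallel(a, b):
    # Log-step doubling scan over the associative operator
    # (a1, b1) o (a2, b2) = (a1 * a2, a2 * b1 + b2).
    # The repository's pscan uses a Blelloch up-/down-sweep instead, but both
    # rely on the same associativity to parallelize over the time dimension.
    a, b = a.clone(), b.clone()
    L = b.size(1)
    step = 1
    while step < L:
        a_prev = a[:, :L - step].clone()
        b_prev = b[:, :L - step].clone()
        b[:, step:] = b[:, step:] + a[:, step:] * b_prev
        a[:, step:] = a[:, step:] * a_prev
        step *= 2
    return b


B, L, D = 2, 8, 4         # batch, sequence length, state size
a = torch.rand(B, L, D)   # per-step decay terms (the discretized A)
x = torch.randn(B, L, D)  # per-step inputs (the discretized B times x)
assert torch.allclose(scan_sequential(a, x), scan_parallel(a, x), atol=1e-5)
```

Because each doubling step is one batched tensor operation, the whole scan needs only on the order of log2(L) such steps instead of L sequential ones.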

Quick Start & Requirements

  • Install via pip: pip install mambapy (a minimal usage sketch follows this list).
  • The PyTorch version runs anywhere PyTorch does; a CUDA-capable GPU is recommended for training throughput. The MLX version targets Apple Silicon Macs.
  • Official documentation and examples are available within the repository.
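A minimal forward-pass sketch, assuming the Mamba / MambaConfig interface described in the repository README (field names such as d_model and n_layers are taken from there; check the README for the current API):

```python
import torch
from mambapy.mamba import Mamba, MambaConfig

# Configure a small model: 16-dim residual stream, 2 Mamba blocks.
config = MambaConfig(d_model=16, n_layers=2)
model = Mamba(config)

# Input is (batch, sequence length, d_model); the output has the same shape.
B, L, D = 2, 64, 16
x = torch.randn(B, L, D)
y = model(x)
assert y.shape == x.shape
```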

Highlighted Details

  • Implements Mamba, Mamba-2, and Jamba (Mamba + attention) architectures.
  • Includes Vision Mamba (ViM) and muP (Maximal Update Parametrization) for hyperparameter transfer.
  • Offers both PyTorch and MLX backends for broader hardware compatibility.
  • Supports loading pretrained models from Hugging Face.

Maintenance & Community

  • The project has been integrated into the Hugging Face Transformers library; a loading sketch follows this list.
  • Development activity has slowed; see the Health Check section below for recent commit and issue statistics.
  • Citation information is provided.
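Because the architecture is also available in Hugging Face Transformers, pretrained checkpoints can be loaded through that API as well. A minimal sketch (the checkpoint name and the transformers version requirement are assumptions, not taken from this repository):

```python
import torch
from transformers import AutoTokenizer, MambaForCausalLM

# "state-spaces/mamba-130m-hf" is a Transformers-converted Mamba checkpoint;
# MambaForCausalLM requires a recent transformers release (4.39 or later).
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

inputs = tokenizer("Mamba is a state-space model that", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

The mambapy package itself can also load pretrained models from the Hugging Face Hub, as noted under Highlighted Details above.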

Licensing & Compatibility

  • The repository does not explicitly state a license in the README. This requires further investigation for commercial use or closed-source linking.

Limitations & Caveats

  • The PyTorch implementation does not use recomputation (activation checkpointing) in the backward pass, so training can require noticeably more memory than the official Mamba implementation.
  • Benchmarks indicate mamba.py trains roughly 2x slower than the official CUDA implementation, with the gap widening as d_state increases.
  • torch.compile is not currently compatible due to issues with the custom PScan autograd function.
Health Check

  • Last Commit: 9 months ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0

Star History

  • 16 stars in the last 30 days

Explore Similar Projects

Starred by Albert Gu (Cofounder of Cartesia; Professor at CMU) and Binyuan Hui (Research Scientist at Alibaba Qwen).

Awesome-Mamba-Papers by yyyujintang (top 0.1%, 1k stars)
Mamba papers collection. Created 1 year ago, updated 11 months ago.
Starred by Jeff Hammerbacher (Cofounder of Cloudera), Stas Bekman (Author of "Machine Learning Engineering Open Book"; Research Engineer at Snowflake), and 25 more.

gpt-neox by EleutherAI (top 0.2%, 7k stars)
Framework for training large-scale autoregressive language models. Created 4 years ago, updated 2 days ago.
Starred by George Hotz (Author of tinygrad; Founder of the tiny corp, comma.ai), Alex Chen (Cofounder of Nexa AI), and 25 more.

mamba by state-spaces (top 0.4%, 16k stars)
Mamba SSM architecture for sequence modeling. Created 1 year ago, updated 2 weeks ago.