flash-sparse-attention by HKUSTDial

Trainable sparse attention for long sequences

Created 11 months ago
607 stars

Top 53.9% on SourcePulse

Project Summary

This project provides a high-performance, trainable sparse attention implementation designed to significantly improve memory efficiency and computational speed for Transformer models handling extremely long sequences. It targets researchers and engineers working with large language models and other sequence-based deep learning architectures, offering a way to scale attention beyond the sequence lengths that dense attention can practically handle.

How It Works

Flash-Sparse-Attention combines the memory efficiency of Flash Attention with sparse computation techniques. It implements trainable sparse attention, allowing the model to dynamically skip low-contribution attention weights via a configurable softmax_threshold. The core approach supports dense, sparse, and gated attention variants, regular and variable-length inputs, causal and local window attention, and optimizations like Split-KV for decoding. This allows for reduced effective compute and memory usage on long sequences.
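To make the thresholding idea concrete, here is a minimal dense NumPy sketch of thresholded sparse attention. It illustrates the concept only: the function name, shapes, and the `threshold` parameter (standing in for the library's `softmax_threshold`) are assumptions for illustration, not the package's actual API, and a real implementation skips the dropped weights rather than computing them.

```python
import numpy as np

def sparse_attention(q, k, v, threshold=0.01, causal=True):
    """Conceptual reference: attention probabilities below `threshold`
    are zeroed and the survivors renormalized, mimicking the effect of
    skipping low-contribution attention weights. Shapes: (seq, dim).
    Illustrative only; not the flash-sparse-attn API."""
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                    # (Lq, Lk) logits
    if causal:
        Lq, Lk = scores.shape
        future = np.triu(np.ones((Lq, Lk), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)   # mask future keys
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)               # softmax
    p = np.where(p < threshold, 0.0, p)              # drop small weights
    p /= p.sum(axis=-1, keepdims=True)               # renormalize survivors
    return p @ v
```

With `threshold=0.0` this reduces to ordinary dense softmax attention; raising the threshold trades a small amount of accuracy for sparsity.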

Quick Start & Requirements

  • Installation:
    • From PyPI: pip install flash-sparse-attn
    • From source: git clone https://github.com/flash-algo/flash-sparse-attn.git && cd flash-sparse-attn && pip install .
  • Prerequisites: Linux (Ubuntu 22.04+), NVIDIA GPU (Compute Capability 8.0+), Python 3.9+, PyTorch 2.5.1+. Triton is installed automatically.
  • Links: Repository: https://github.com/flash-algo/flash-sparse-attn.git

Highlighted Details

  • Supports Grouped Query Attention (GQA) and Multi-Query Attention (MQA).
  • Includes optimizations for decoding workloads via a Split-KV path.
  • Performance benchmarks are available for forward, backward, and decoding, comparing against FlashAttention.
  • Benchmarking utilizes Qwen model attention projections and the Needle-in-a-Haystack dataset for realistic LLM workloads.
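For readers unfamiliar with GQA/MQA, the sketch below shows the standard trick of sharing each key/value head across a group of query heads (MQA is the special case of a single KV head). The function name and shapes are illustrative assumptions, not part of this library.

```python
import numpy as np

def expand_kv_heads(kv, num_q_heads):
    """Grouped-query attention shares one KV head across a group of
    query heads. To reuse a per-head dense kernel, each KV head is
    repeated to match the query head count. Shape: (heads, seq, dim).
    Illustrative helper, not the flash-sparse-attn API."""
    num_kv_heads = kv.shape[0]
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    # Repeat each KV head group_size times along the head axis.
    return np.repeat(kv, group_size, axis=0)
```

With 8 query heads and 2 KV heads, query heads 0-3 share KV head 0 and query heads 4-7 share KV head 1; MQA corresponds to `num_kv_heads == 1`.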

Maintenance & Community

The project acknowledges contributions from OpenSeek, Flash-Attention, and NVIDIA CUTLASS. No specific community channels (like Discord/Slack) or roadmap links are provided in the README. The citation points to an arXiv paper from 2025, indicating recent development activity.

Licensing & Compatibility

The README does not explicitly state the project's license, so compatibility with commercial or closed-source applications should be verified before adoption.

Limitations & Caveats

Support for arbitrary mask and bias shapes is available in a separate branch, not the main branch. Features such as Paged Attention, TMA, WGMMA, and FP8 low precision are listed as future aims, indicating they are not yet implemented.

Health Check

  • Last commit: 4 days ago
  • Responsiveness: Inactive
  • Pull requests (30d): 22
  • Issues (30d): 0
  • Star history: 57 stars in the last 30 days

Explore Similar Projects

Starred by Alex Yu (Research Scientist at OpenAI; Cofounder of Luma AI), Yineng Zhang (Inference Lead at SGLang; Research Scientist at Together AI), and 1 more.

ring-attention-pytorch by lucidrains

549 stars
Pytorch impl of Ring Attention for near-infinite context
Created 2 years ago; updated 11 months ago
Starred by Mehdi Amini (Author of MLIR; Distinguished Engineer at NVIDIA), Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), and 15 more.

flashinfer by flashinfer-ai

5k stars
Kernel library for LLM serving
Created 2 years ago; updated 1 day ago