PyTorch implementation of Ring Attention for near-infinite context
This repository provides PyTorch implementations of Ring Attention and Striped Attention, which let transformers process significantly longer sequences (millions of tokens) by sharding the attention computation across multiple GPUs. It also includes Grouped Query Attention to further reduce communication costs, and targets researchers and engineers working with large-scale language models.
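As an illustrative back-of-envelope (the numbers and variable names below are assumptions, not taken from the repository), the key/value activations are what each ring member must send every step, so reducing the number of key/value heads via Grouped Query Attention shrinks the per-step payload proportionally:

```python
# Hypothetical sizing example: bytes of key/value activations one ring member
# sends per ring step, with and without grouped query attention (GQA).
seq_per_device = 4096   # tokens held by each ring member (assumed)
dim_head = 64
query_heads = 8
kv_heads = 2            # GQA: key/value heads shared across query heads
bytes_per_el = 2        # fp16 / bf16

mha_kv_bytes = 2 * seq_per_device * query_heads * dim_head * bytes_per_el
gqa_kv_bytes = 2 * seq_per_device * kv_heads * dim_head * bytes_per_el

print(f"per-step k/v payload  MHA: {mha_kv_bytes / 2**20:.1f} MiB  GQA: {gqa_kv_bytes / 2**20:.1f} MiB")
```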
How It Works
Ring Attention splits the sequence dimension across GPUs and passes key/value blocks around a ring of devices, so that communication overlaps with blockwise attention computation. Striped Attention builds on this by permuting the sequence to balance the workload across devices in causal (autoregressive) transformers. The implementation uses Flash Attention style tiled computation, with custom Triton kernels covering both the forward and backward passes.
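To make the ring pattern concrete, here is a minimal single-process sketch (not the repository's code; all names are illustrative): the sequence is split into blocks, one per simulated ring member, and key/value blocks are rotated around the ring while an online (log-sum-exp) softmax accumulates each query block's output, so no device ever materializes the full attention matrix.

```python
# Educational single-process sketch of the ring attention pattern (illustrative,
# not the library's implementation). Each "device" owns one query block;
# key/value blocks are rotated around a ring while a streaming softmax
# accumulates the result.

import torch
import torch.nn.functional as F

def ring_attention_reference(q, k, v, n_devices):
    # q, k, v: (batch, seq_len, dim); seq_len must divide evenly by n_devices
    b, n, d = q.shape
    scale = d ** -0.5

    q_blocks = q.chunk(n_devices, dim=1)
    k_blocks = list(k.chunk(n_devices, dim=1))
    v_blocks = list(v.chunk(n_devices, dim=1))

    outputs = []
    for i, qi in enumerate(q_blocks):
        # running statistics for the online softmax of this query block
        acc = torch.zeros_like(qi)                          # weighted value sum
        row_max = torch.full(qi.shape[:-1], float('-inf'))  # running max logit
        row_sum = torch.zeros(qi.shape[:-1])                # running softmax denominator

        # simulate the ring: each step "receives" the next key/value block
        for step in range(n_devices):
            j = (i + step) % n_devices
            kj, vj = k_blocks[j], v_blocks[j]

            logits = torch.einsum('b i d, b j d -> b i j', qi, kj) * scale

            new_max = torch.maximum(row_max, logits.amax(dim=-1))
            correction = torch.exp(row_max - new_max)
            p = torch.exp(logits - new_max[..., None])

            row_sum = row_sum * correction + p.sum(dim=-1)
            acc = acc * correction[..., None] + torch.einsum('b i j, b j d -> b i d', p, vj)
            row_max = new_max

        outputs.append(acc / row_sum[..., None])

    return torch.cat(outputs, dim=1)

# sanity check against full (non-causal) attention
q, k, v = (torch.randn(2, 64, 32) for _ in range(3))
out = ring_attention_reference(q, k, v, n_devices=4)
ref = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, ref, atol=1e-5)
```

In the actual multi-GPU setting, the inner loop's block hand-off becomes a peer-to-peer send/receive between neighboring ranks, which can overlap with the attention compute for the block already in hand.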
Quick Start & Requirements
Install with pip install ring-attention-pytorch. To run the bundled test scripts, first install the requirements (pip install -r requirements.txt), then run python assert.py or python assert_tree_attn.py.
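Assuming the package exposes a RingAttention module as in the upstream README (the constructor arguments below are an assumption and may differ between versions), basic usage looks roughly like this:

```python
import torch
from ring_attention_pytorch import RingAttention  # assumes this export exists

# argument names follow the upstream README and are an assumption;
# check the installed version before relying on them
attn = RingAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,
    auto_shard_seq = True,      # shard the sequence across the ring automatically
    ring_attn = True,
    striped_ring_attn = True,   # enable striped attention for better load balance
    ring_seq_size = 512         # sequence length handled per ring member
)

tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)         # same shape as the input

assert attended.shape == tokens.shape
```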
Maintenance & Community
The project is sponsored by the A16Z Open Source AI Grant Program. It acknowledges contributions from Tri Dao (Flash Attention) and Phil Tillet (Triton).
Licensing & Compatibility
The repository does not explicitly state a license in the README.
Limitations & Caveats
The "Todo" list indicates ongoing development, with several features and optimizations still pending, including distributed PyTorch testing and specific dataset sharding strategies for training. Some CUDA kernel optimizations are noted as "hacks" or requiring further validation.