Unified sequence parallel attention for long context LLM training/inference
This repository provides Unified Sequence Parallelism (USP), a sequence-parallel attention implementation designed to enable efficient training and inference of Large Language Models (LLMs) with long contexts. It addresses the limitations of existing methods such as DeepSpeed-Ulysses and Ring-Attention by combining their strengths, offering improved versatility and performance for researchers and engineers working with extended sequence lengths.
How It Works
USP combines DeepSpeed-Ulysses-Attention and Ring-Attention to overcome the drawbacks of each: Ulysses' parallel degree is capped by the number of attention heads, which also limits how it composes with Tensor Parallelism, while Ring-Attention is less compute-efficient and prone to communication deadlocks. The unified approach allows flexible configuration (e.g., "zigzag" or "stripe" layouts for load balancing under causal masking) and supports hardware backends without FlashAttention via a pure-PyTorch implementation.
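As a rough sketch of the resulting two-dimensional decomposition (the degrees, shapes, and shard arithmetic below are illustrative assumptions, not library defaults):

```python
# Illustrative shard arithmetic for USP's 2-D decomposition
# (numbers are assumptions for the example, not library defaults).
world_size = 8                      # total GPUs in the sequence-parallel group
ulysses_degree, ring_degree = 4, 2  # Ulysses (head all-to-all) x Ring (K/V ring)
assert ulysses_degree * ring_degree == world_size

batch, seqlen, nheads, head_dim = 1, 32768, 32, 128

# Outside attention, every rank holds an equal slice of the sequence.
local_seqlen = seqlen // world_size              # 4096 tokens per rank

# Inside attention, the Ulysses all-to-all trades sequence for heads ...
heads_per_rank = nheads // ulysses_degree        # 8 heads per rank
query_span = seqlen // ring_degree               # 16384 query tokens per rank
# ... and the ring exchange streams K/V blocks between the ring_degree peers
# so each local query span attends to the full sequence.
print(local_seqlen, heads_per_rank, query_span)
```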
Quick Start & Requirements
Install with `pip install yunchang` (requires `flash-attn` 2.6.x or 2.7.x for GPU acceleration); FlashAttention V3 must be installed from source. A pure-PyTorch implementation is available for hardware without FlashAttention (`attn_type=AttnType.TORCH`), though it does not support the backward pass. AMD GPU installation instructions are in `install_amd.md`.
Usage involves initializing the sequence-parallel process groups with `set_seq_parallel_pg` and then using `LongContextAttention` as a drop-in replacement for a standard attention layer, as sketched below. Examples and testing scripts are provided in the `test/`
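A minimal usage sketch follows; the exact argument names, call signatures, and tensor layout are assumptions drawn from this summary and should be checked against the scripts in `test/`:

```python
# Minimal usage sketch; argument names and tensor layout are assumptions,
# see the test/ scripts for the authoritative invocation.
import torch
import torch.distributed as dist
from yunchang import set_seq_parallel_pg, LongContextAttention
from yunchang.kernels import AttnType

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# world_size must factor into ulysses_degree * ring_degree.
sp_ulysses_degree = 4
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

# Drop-in replacement for a standard attention layer.
attn = LongContextAttention(ring_impl_type="zigzag", attn_type=AttnType.FA)

# Each rank holds its local shard of the sequence:
# (batch, seqlen // world_size, nheads, head_dim)
q = torch.randn(1, 4096, 32, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = attn(q, k, v, causal=True)  # same shape as the local q shard
```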
directory.
Highlighted Details
`LongContextAttention` exposes a configurable `ring_impl_type` ("zigzag", "stripe", "basic") and `attn_type` (FA, FA3, TORCH).
Maintenance & Community
The project has been integrated into several notable projects including NVIDIA/TransformerEngine, xdit-project/xDiT, and NVlabs/VILA, indicating active adoption and validation.
Licensing & Compatibility
The repository is released under the MIT License, permitting commercial use and integration with closed-source projects.
Limitations & Caveats
The PyTorch-based attention implementation (`attn_type=AttnType.TORCH`) does not support the backward pass. The "zigzag" and "stripe" ring implementations require the local sequence shards to be pre-arranged in their specific layouts across ranks, as illustrated below.
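To make the zigzag requirement concrete, here is a minimal sketch of how a global sequence is typically pre-sharded for that layout (a hypothetical standalone helper, not part of the documented API):

```python
# Sketch of the zigzag sequence layout assumed by ring_impl_type="zigzag"
# (hypothetical helper; yunchang's test/ scripts show the library's own way
# of producing the per-rank shards).
import torch

def zigzag_shard(x: torch.Tensor, rank: int, world_size: int, dim: int = 1) -> torch.Tensor:
    """Give rank i chunks i and (2*world_size - 1 - i) of the global sequence,
    which balances causal-attention work across the ring."""
    chunks = x.chunk(2 * world_size, dim=dim)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)

# Example: an 8192-token sequence sharded across a ring of 4 ranks.
full = torch.randn(1, 8192, 32, 128)
local = zigzag_shard(full, rank=0, world_size=4)  # rank 0 gets chunks 0 and 7
print(local.shape)                                # torch.Size([1, 2048, 32, 128])
```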