ringattention by haoliuhl

JAX implementation of RingAttention for large-context models (research paper)

created 2 years ago
719 stars

Top 48.8% on sourcepulse

Project Summary

This repository provides a JAX implementation of RingAttention, a technique that enables training transformer models with significantly larger, potentially near-infinite, context windows. It targets researchers and engineers working on large-context models, offering a way to overcome per-device memory limits by distributing attention and feedforward computation across multiple devices.

How It Works

RingAttention distributes attention and feedforward computation across devices arranged in a ring, overlapping communication with computation: each device holds one block of the sequence and computes attention for its local query block while key/value blocks rotate around the ring. This blockwise parallelization allows training with context lengths orders of magnitude longer than standard attention, without adding computational or communication overhead beyond the rotation itself. The implementation leverages JAX's shard_map for efficient parallel execution.
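The repository's kernels are fused and considerably more elaborate, but the core idea can be sketched in plain JAX. The toy example below is not the library's API (names such as ring_attention_block are illustrative): it shards the sequence across a one-axis device mesh with shard_map, rotates key/value blocks around the ring with jax.lax.ppermute, and merges the partial results with a streaming softmax so the final output matches full attention.

    import math
    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    # Toy dimensions (illustrative only); SEQ must be divisible by the device count.
    SEQ, HEAD_DIM = 512, 64
    devices = jax.devices()
    n_dev = len(devices)
    mesh = Mesh(np.array(devices), axis_names=("sp",))  # one sequence-parallel axis

    def ring_attention_block(q, k, v):
        """Each device holds one sequence block of q/k/v. K/V blocks rotate around
        the ring; partial results are merged with a streaming softmax so the
        output matches full (exact) attention."""
        scale = 1.0 / math.sqrt(q.shape[-1])
        numerator = jnp.zeros_like(q)                   # running weighted sum of values
        denominator = jnp.zeros(q.shape[:-1])           # running softmax normalizer
        running_max = jnp.full(q.shape[:-1], -jnp.inf)  # running max of scores

        perm = [(i, (i + 1) % n_dev) for i in range(n_dev)]  # ring permutation
        for _ in range(n_dev):
            scores = (q @ k.T) * scale
            new_max = jnp.maximum(running_max, scores.max(axis=-1))
            correction = jnp.exp(running_max - new_max)
            probs = jnp.exp(scores - new_max[..., None])
            numerator = numerator * correction[..., None] + probs @ v
            denominator = denominator * correction + probs.sum(axis=-1)
            running_max = new_max
            # Hand this device's K/V block to the next device in the ring; the real
            # implementation overlaps this communication with the computation above.
            k = jax.lax.ppermute(k, axis_name="sp", perm=perm)
            v = jax.lax.ppermute(v, axis_name="sp", perm=perm)
        return numerator / denominator[..., None]

    ring_attention = shard_map(
        ring_attention_block,
        mesh=mesh,
        in_specs=(P("sp", None),) * 3,   # shard the sequence axis across devices
        out_specs=P("sp", None),
        check_rep=False,
    )

    q = k = v = jnp.ones((SEQ, HEAD_DIM))
    out = ring_attention(q, k, v)  # equals softmax(q k^T / sqrt(d)) v over the full sequence

Because each device only ever materializes one query block and one key/value block of scores at a time, per-device memory stays constant while the total context grows with the number of devices in the ring.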

Quick Start & Requirements

  • Install: pip install ringattention
  • Prerequisites: JAX and hardware suitable for distributed training (multiple GPUs/TPUs).
  • Example usage and detailed configuration options are available in the README; a minimal device sanity check follows below.
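As a quick, hedged sanity check (not taken from the README), the snippet below uses only standard JAX calls to confirm that accelerators are visible; RingAttention shards the sequence across these devices, so the achievable context length grows with the device count.

    import jax

    # RingAttention distributes the sequence across these devices, so context
    # length scales with how many accelerators JAX can see.
    print(jax.devices())       # available GPUs/TPUs
    print(jax.device_count())  # number of devices in the ring

    # The package's entry point is assumed here; check the README for the
    # actual import and call signature:
    # from ringattention import ringattention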

Highlighted Details

  • Enables training with context lengths proportional to the number of devices.
  • Achieves near-infinite context by distributing computation and overlapping communication.
  • Utilized in Large World Model (LWM) for million-length vision-language training.
  • Offers configurable query_chunk_size and key_chunk_size for performance tuning.

Maintenance & Community

The project is associated with research from UC Berkeley and Stanford. Key papers are cited for reference. No specific community channels (Discord/Slack) or active development roadmap are mentioned in the README.

Licensing & Compatibility

The provided README does not state a license, so licensing should be verified before commercial use or integration into closed-source projects.

Limitations & Caveats

The implementation is JAX-specific, so it is not directly usable outside the JAX ecosystem. The README does not detail hardware requirements beyond the need for a distributed multi-GPU/TPU setup, nor does it report performance benchmarks against other large-context attention mechanisms.

Health Check

  • Last commit: 6 months ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0

Star History

13 stars in the last 90 days
