ttt-lm-jax by test-time-training

JAX implementation of the Test-Time Training (TTT) RNN research paper

created 1 year ago
418 stars

Top 71.2% on sourcepulse

Project Summary

This repository provides the official JAX implementation for "Learning to (Learn at Test Time): RNNs with Expressive Hidden States." It addresses the limitations of traditional RNNs in long-context modeling by introducing a novel "Test-Time Training" (TTT) layer where the hidden state is a self-supervised learning model itself. This allows for linear complexity with expressive hidden states, benefiting researchers and practitioners working with long sequence data.

How It Works

The core innovation lies in the TTT layers (TTT-Linear and TTT-MLP), which replace standard RNN hidden states. Instead of a fixed representation, the hidden state is a trainable machine learning model (a linear model or a two-layer MLP). The update rule for this hidden state is a step of self-supervised learning, allowing it to adapt and learn from test sequences. This approach aims to achieve the linear complexity of RNNs while overcoming their expressive power limitations in long contexts.
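The update rule described above can be sketched in a few lines of JAX. This is an illustrative simplification, not the repository's actual implementation: the hidden state is a d×d linear model `W`, the self-supervised inner loss is plain self-reconstruction (the paper uses learned low-rank projections for the training views), and one SGD step on that loss per token serves as the RNN state transition. A `jax.lax.scan` over tokens makes the linear-in-sequence-length complexity explicit.

```python
import jax
import jax.numpy as jnp

def ttt_linear_layer(tokens, lr=0.1):
    """Hypothetical TTT-Linear sketch: the hidden state is a linear
    model W, updated by one gradient step of a self-supervised
    reconstruction loss per token."""
    d = tokens.shape[-1]

    def inner_loss(W, x):
        # Self-supervised task (simplified): reconstruct the token.
        return jnp.mean((x @ W - x) ** 2)

    def step(W, x):
        # One SGD step on the inner loss *is* the state update.
        W = W - lr * jax.grad(inner_loss)(W, x)
        # The output token is computed with the updated state.
        return W, x @ W

    W0 = jnp.zeros((d, d))
    _, outputs = jax.lax.scan(step, W0, tokens)  # linear in seq length
    return outputs

seq = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
out = ttt_linear_layer(seq)
print(out.shape)  # (16, 8)
```

Because the per-token cost is a fixed-size gradient step rather than attention over all previous tokens, cost grows linearly with context length while the state itself remains a trainable model (swap the linear `W` for a two-layer MLP to get the TTT-MLP variant).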

Quick Start & Requirements

  • Installation: Install GPU requirements via pip install -r requirements/gpu_requirements.txt. For TPU, use pip install -r requirements/tpu_requirements.txt.
  • Prerequisites: Python 3.11, JAX, and WandB for logging. Datasets (Llama-2 tokenized) must be downloaded from Google Cloud buckets.
  • Setup: Requires downloading large datasets and configuring dataset_path.
  • Links: Paper, PyTorch Codebase, Model Docs

Highlighted Details

  • Official JAX implementation of TTT layers.
  • Achieves linear complexity with expressive hidden states for long-context modeling.
  • Built on EasyLM; uses the FlashAttention dataloader.
  • Scripts provided for replicating paper experiments.

Maintenance & Community

  • No specific contributors, sponsorships, or community links (Discord/Slack) are mentioned in the README.

Licensing & Compatibility

  • The README does not explicitly state a license. Compatibility for commercial or closed-source use is not specified.

Limitations & Caveats

The repository requires significant setup, including downloading large datasets and potentially configuring model sharding for larger models. The lack of explicit licensing information and community channels may pose adoption challenges.

Health Check

  • Last commit: 11 months ago
  • Responsiveness: 1 week
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 12 stars in the last 90 days
