ttt-lm-pytorch by test-time-training

PyTorch model for sequence modeling layers with expressive hidden states

Created 1 year ago
1,254 stars

Top 31.6% on SourcePulse

Project Summary

This repository provides the official PyTorch implementation for "Learning to (Learn at Test Time): RNNs with Expressive Hidden States." It addresses the limitations of traditional RNNs in long-context modeling by introducing Test-Time Training (TTT) layers, which use a machine learning model as their hidden state, updated via self-supervised learning. This approach offers linear complexity with enhanced expressive power for sequence modeling tasks, targeting researchers and practitioners interested in advanced RNN architectures.

How It Works

The core innovation lies in the TTT layers (TTT-Linear and TTT-MLP), where the hidden state is itself a trainable ML model (a linear model or a two-layer MLP). The update rule for this hidden state involves a step of self-supervised learning, allowing the model to adapt and learn even during inference on test sequences. This "test-time training" mechanism aims to overcome the expressiveness limitations of standard RNN hidden states while maintaining linear computational complexity, making it suitable for long-context scenarios where self-attention's quadratic complexity is prohibitive.
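To make the "hidden state as a model" idea concrete, here is a minimal, illustrative PyTorch sketch of a TTT-Linear-style step: the hidden state is the weight matrix of a small linear model, and processing each token performs one self-supervised gradient step on that matrix before producing the output. This is not the repository's implementation (the actual layers use learned corruption/reconstruction views and a parallelized inner loop); the dimensions, loss, and inner learning rate below are arbitrary choices for illustration.

```python
import torch

d = 16
W = torch.zeros(d, d, requires_grad=True)  # hidden state: weights of a small linear model

def ttt_linear_step(W, x_t, inner_lr=0.1):
    """One illustrative TTT-Linear update for a single token representation x_t."""
    # Self-supervised inner loss: here the linear hidden state simply
    # reconstructs the token itself (the paper uses learned views instead).
    loss = ((x_t @ W - x_t) ** 2).mean()
    (grad_W,) = torch.autograd.grad(loss, W)
    # The layer's "update rule" is one gradient step on the inner loss.
    W_new = (W - inner_lr * grad_W).detach().requires_grad_(True)
    # The output at this position comes from the freshly updated hidden state.
    return x_t @ W_new, W_new

x = torch.randn(8, d)  # a toy sequence of token representations
outputs = []
for t in range(x.shape[0]):
    out_t, W = ttt_linear_step(W, x[t])
    outputs.append(out_t)
```

Because the state is updated by an explicit learning step rather than a fixed recurrence, its expressiveness grows with the inner model while the per-token cost stays constant, which is what keeps overall complexity linear in sequence length.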

Quick Start & Requirements

  • Install: pip install "transformers[torch]"
  • Prerequisites: PyTorch, Hugging Face Transformers.
  • Usage: Load models and generate text with Hugging Face AutoTokenizer and the repo's TTTForCausalLM; see the sketch after this list for a quick-start example.
  • Resources: The repository links to faster inference kernels and benchmarks.
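The snippet below sketches the generation flow, assuming the repository's `ttt` module exposes `TTTConfig` and `TTTForCausalLM` as described in its README and that a Llama-2 tokenizer is used for text; treat the tokenizer checkpoint and generation arguments as placeholders to check against the repo's own example.

```python
import torch
from transformers import AutoTokenizer
from ttt import TTTConfig, TTTForCausalLM  # assumed module layout from this repository

# Build a randomly initialized TTT model from the default config
# (load pretrained weights instead if a checkpoint is available).
config = TTTConfig()
model = TTTForCausalLM(config)
model.eval()

# Tokenizer checkpoint is a placeholder; substitute the one the repo's example uses.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Test-time training lets the hidden state keep learning", return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```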

Highlighted Details

  • Official PyTorch implementation of the TTT paper.
  • Integrates with Hugging Face Transformers for easy model loading and generation.
  • Introduces TTT-Linear and TTT-MLP layers with linear complexity and expressive hidden states.
  • Focuses on inference capabilities, with a separate JAX codebase recommended for training.

Maintenance & Community

The project is associated with the "test-time-training" organization. Further community engagement details, such as Discord/Slack channels or roadmaps, are not explicitly mentioned in the README.

Licensing & Compatibility

The README does not explicitly state a license. Compatibility for commercial use or closed-source linking would require clarification on the licensing terms.

Limitations & Caveats

The README explicitly states this is a "naive implementation of TTT layers for tutorial purposes" and is not recommended for training due to a lack of systems optimization, leading to slow training times. For training and replicating paper results, the JAX codebase is recommended.

Health Check

  • Last Commit: 1 year ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 14 stars in the last 30 days

Explore Similar Projects

Starred by Wing Lian (Founder of Axolotl AI), Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), and 2 more.

recurrent-pretraining by seal-rg
0% · 827 stars
Pretraining code for depth-recurrent language model research
Created 7 months ago · Updated 1 week ago

Starred by Shizhe Diao (Author of LMFlow; Research Scientist at NVIDIA), Tri Dao (Chief Scientist at Together AI), and 1 more.

hnet by goombalab
1.5% · 722 stars
Hierarchical sequence modeling with dynamic chunking
Created 2 months ago · Updated 1 month ago

Starred by Elie Bursztein (Cybersecurity Lead at Google DeepMind), Omar Khattab (Coauthor of DSPy, ColBERT; Professor at MIT), and 15 more.

gpt-neo by EleutherAI
0.0% · 8k stars
GPT-2/3-style model implementation using mesh-tensorflow
Created 5 years ago · Updated 3 years ago