PyTorch implementation of sequence modeling layers with expressive hidden states
This repository provides the official PyTorch implementation for "Learning to (Learn at Test Time): RNNs with Expressive Hidden States." It addresses the limitations of traditional RNNs in long-context modeling by introducing Test-Time Training (TTT) layers, which use a machine learning model as their hidden state, updated via self-supervised learning. This approach offers linear complexity with enhanced expressive power for sequence modeling tasks, targeting researchers and practitioners interested in advanced RNN architectures.
How It Works
The core innovation lies in the TTT layers (TTT-Linear and TTT-MLP), where the hidden state is itself a trainable ML model (a linear model or a two-layer MLP). The update rule for this hidden state involves a step of self-supervised learning, allowing the model to adapt and learn even during inference on test sequences. This "test-time training" mechanism aims to overcome the expressiveness limitations of standard RNN hidden states while maintaining linear computational complexity, making it suitable for long-context scenarios where self-attention's quadratic complexity is prohibitive.
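To make the mechanism concrete, the sketch below shows a naive, purely illustrative version of the TTT-Linear idea described above: the hidden state is a weight matrix W, each incoming token triggers one gradient step on a self-supervised reconstruction loss, and the updated W then produces that token's output. The projection names, learning rate, and loss are assumptions chosen for illustration; they do not mirror the repository's actual, optimized implementation.

```python
import torch
import torch.nn as nn


class NaiveTTTLinear(nn.Module):
    """Illustrative TTT-Linear sketch: the hidden state is a linear model W,
    updated by one self-supervised gradient step per token (assumed details)."""

    def __init__(self, dim: int, lr: float = 1.0):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)  # "query" view used for the output
        self.k_proj = nn.Linear(dim, dim, bias=False)  # inner-loop training input
        self.v_proj = nn.Linear(dim, dim, bias=False)  # inner-loop reconstruction target
        self.lr = lr

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); tokens are processed sequentially.
        batch, seq_len, dim = x.shape
        W = torch.zeros(batch, dim, dim, device=x.device, dtype=x.dtype)
        outputs = []
        for t in range(seq_len):
            xt = x[:, t]                                   # (batch, dim)
            k = self.k_proj(xt)                            # inner-model input
            v = self.v_proj(xt)                            # self-supervised target
            # Inner-loop loss ||k @ W - v||^2; its gradient wrt W is 2 k^T (kW - v).
            pred = torch.bmm(k.unsqueeze(1), W).squeeze(1)
            grad_W = 2 * torch.bmm(k.unsqueeze(2), (pred - v).unsqueeze(1))
            W = W - self.lr * grad_W                       # one test-time training step
            # Output: apply the *updated* hidden state to the query view.
            q = self.q_proj(xt)
            outputs.append(torch.bmm(q.unsqueeze(1), W).squeeze(1))
        return torch.stack(outputs, dim=1)                 # (batch, seq_len, dim)
```

Because each token triggers only a constant amount of work on W, the cost grows linearly with sequence length, which is the property highlighted above; the sequential Python loop here is only for clarity.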
Quick Start & Requirements
Install the dependency with pip install "transformers[torch]". The model is loaded through the Hugging Face AutoTokenizer and the repository's TTTForCausalLM class; see the Python code snippet provided in the repository README for a complete quick-start example.
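The sketch below illustrates the expected usage pattern. Apart from AutoTokenizer and TTTForCausalLM, which the summary above names, the import path, TTTConfig, the default configuration, and the tokenizer checkpoint are assumptions and may differ from the repository's actual quick-start snippet.

```python
from transformers import AutoTokenizer
from ttt import TTTForCausalLM, TTTConfig  # assumed import path within the repo

# Build a model from a default configuration (assumed API).
config = TTTConfig()
model = TTTForCausalLM(config)
model.eval()

# Any compatible Hugging Face tokenizer should work; this checkpoint is a placeholder.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

input_ids = tokenizer("Greetings from TTT!", return_tensors="pt").input_ids

# Generate a short continuation with the standard generate() interface.
out_ids = model.generate(input_ids=input_ids, max_length=50)
print(tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0])
```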
Highlighted Details
Maintenance & Community
The project is associated with the "test-time-training" organization. Further community engagement details, such as Discord/Slack channels or roadmaps, are not explicitly mentioned in the README.
Licensing & Compatibility
The README does not explicitly state a license. Compatibility for commercial use or closed-source linking would require clarification on the licensing terms.
Limitations & Caveats
The README explicitly states this is a "naive implementation of TTT layers for tutorial purposes" and is not recommended for training due to a lack of systems optimization, leading to slow training times. For training and replicating paper results, the JAX codebase is recommended.