JAX implementation of test-time training RNN research paper
This repository provides the official JAX implementation of "Learning to (Learn at Test Time): RNNs with Expressive Hidden States." It addresses the limitations of traditional RNNs in long-context modeling by introducing "Test-Time Training" (TTT) layers, in which the hidden state is itself a machine learning model updated by self-supervised learning. This keeps the linear complexity of RNNs while making the hidden state more expressive, which benefits researchers and practitioners working with long sequences.
How It Works
The core innovation lies in the TTT layers (TTT-Linear and TTT-MLP), which replace standard RNN hidden states. Instead of a fixed representation, the hidden state is a trainable machine learning model (a linear model or a two-layer MLP). The update rule for this hidden state is a step of self-supervised learning, allowing it to adapt and learn from test sequences. This approach aims to achieve the linear complexity of RNNs while overcoming their expressive power limitations in long contexts.
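The idea of a hidden state that is itself a small model can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the repository's code: the names inner_loss, ttt_linear_sketch, theta_k, theta_v, theta_q, and eta are hypothetical, the hidden state is a plain linear map W, and the sketch takes one gradient step per token on a reconstruction loss. Anything else a full implementation would need is left out.

```python
import jax
import jax.numpy as jnp


def inner_loss(W, x, theta_k, theta_v):
    """Self-supervised loss for one token (assumed form: reconstruction).

    The linear hidden-state model W maps a "training view" of the token
    (theta_k @ x) to a "label view" of the same token (theta_v @ x).
    """
    pred = W @ (theta_k @ x)
    target = theta_v @ x
    return jnp.sum((pred - target) ** 2)


def ttt_linear_sketch(xs, theta_k, theta_v, theta_q, W0, eta):
    """Scan over a sequence, updating the hidden state W by gradient descent.

    xs has shape (T, d); W0 has shape (d, d). Cost is linear in T because
    each token triggers exactly one fixed-size update.
    """
    grad_fn = jax.grad(inner_loss)  # gradient with respect to W

    def step(W, x):
        # Update rule: one step of self-supervised learning on this token.
        W = W - eta * grad_fn(W, x, theta_k, theta_v)
        # Output rule: apply the freshly updated model to a "test view".
        z = W @ (theta_q @ x)
        return W, z

    W_final, zs = jax.lax.scan(step, W0, xs)
    return W_final, zs


# Toy usage with identity projections (all values are illustrative).
d, T = 4, 16
xs = jax.random.normal(jax.random.PRNGKey(0), (T, d))
eye = jnp.eye(d)
W_final, zs = ttt_linear_sketch(xs, eye, eye, eye, jnp.zeros((d, d)), eta=0.1)
```

Using jax.lax.scan keeps the per-token update sequential, as in an RNN, while the carried state is the weight matrix rather than a fixed activation vector. Swapping the linear map for a two-layer MLP would change only inner_loss and the output rule in this sketch.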
Quick Start & Requirements
For GPU, use pip install -r requirements/gpu_requirements.txt. For TPU, use pip install -r requirements/tpu_requirements.txt.
The dataset location is set via dataset_path.
Highlighted Details
Maintenance & Community
Licensing & Compatibility
Limitations & Caveats
The repository requires significant setup, including downloading large datasets and potentially configuring model sharding for larger models. The lack of explicit licensing information and community channels may pose adoption challenges.