ttt-lm-pytorch by test-time-training

PyTorch implementation of sequence modeling layers with expressive hidden states

created 1 year ago · 1,238 stars · Top 32.5% on sourcepulse

Project Summary

This repository provides the official PyTorch implementation for "Learning to (Learn at Test Time): RNNs with Expressive Hidden States." It addresses the limitations of traditional RNNs in long-context modeling by introducing Test-Time Training (TTT) layers, which use a machine learning model as their hidden state, updated via self-supervised learning. This approach offers linear complexity with enhanced expressive power for sequence modeling tasks, targeting researchers and practitioners interested in advanced RNN architectures.

How It Works

The core innovation lies in the TTT layers (TTT-Linear and TTT-MLP), where the hidden state is itself a trainable ML model: a linear model or a two-layer MLP, respectively. The update rule for this hidden state is a gradient step on a self-supervised loss, so the model keeps adapting even while it processes test sequences at inference time. This "test-time training" mechanism aims to overcome the expressiveness limitations of standard RNN hidden states while maintaining linear computational complexity, making it suitable for long-context scenarios where self-attention's quadratic cost is prohibitive.
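
To make the mechanism concrete, below is a minimal, self-contained sketch of a TTT-Linear-style step in plain PyTorch. It is not the repository's optimized layer: the hidden state is the weight matrix of a small linear model, and each incoming token triggers one gradient step on a reconstruction loss before the output is read out. The projection names, the loss, and the learning rate here are illustrative assumptions.

```python
import torch


def ttt_linear_sketch(x, lr=0.1):
    """x: (seq_len, dim) token embeddings; returns (seq_len, dim) outputs."""
    seq_len, dim = x.shape
    W = torch.zeros(dim, dim)              # hidden state = weights of a linear model
    # Fixed "outer" projections; in the real layer these are learned during training.
    proj_k = torch.randn(dim, dim) / dim ** 0.5
    proj_v = torch.randn(dim, dim) / dim ** 0.5
    proj_q = torch.randn(dim, dim) / dim ** 0.5

    outputs = []
    for t in range(seq_len):
        k = x[t] @ proj_k                  # input view of the current token
        v = x[t] @ proj_v                  # self-supervised reconstruction target
        q = x[t] @ proj_q                  # query used to read the hidden state

        pred = k @ W                       # inner model's prediction for this token
        grad = torch.outer(k, pred - v)    # grad of 0.5 * ||k @ W - v||^2 w.r.t. W
        W = W - lr * grad                  # one test-time training step per token

        outputs.append(q @ W)              # read out with the freshly updated state
    return torch.stack(outputs)


out = ttt_linear_sketch(torch.randn(16, 32))
print(out.shape)  # torch.Size([16, 32])
```

Because each token is handled with a fixed amount of work, the loop scales linearly with sequence length, which is the property the paper contrasts with self-attention's quadratic cost.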

Quick Start & Requirements

  • Install: pip install "transformers[torch]"
  • Prerequisites: PyTorch, Hugging Face Transformers.
  • Usage: Load models and generate text with Hugging Face AutoTokenizer and the repository's TTTForCausalLM class; a hedged example sketch follows this list.
  • Resources: The repository links to faster inference kernels and benchmarks.
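
As a reference point, a quick-start sketch along the lines described above is shown below. The `ttt` import path, the default TTTConfig, and the Llama tokenizer checkpoint are assumptions; defer to the repository's README for the authoritative snippet.

```python
import torch
from transformers import AutoTokenizer

# `ttt` is the module shipped in this repository; the class and config names follow
# the summary above, but the exact import path is an assumption.
from ttt import TTTForCausalLM, TTTConfig

config = TTTConfig()                 # default config; assumed to use the TTT-Linear layer
model = TTTForCausalLM(config)       # randomly initialized here; load a checkpoint for real use
model.eval()

# The checkpoint name below is an assumption and may require access approval on the Hub.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

prompt = "Greeting from TTT!"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    out_ids = model.generate(input_ids=input_ids, max_length=50)
print(tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0])
```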

Highlighted Details

  • Official PyTorch implementation of the TTT paper.
  • Integrates with Hugging Face Transformers for easy model loading and generation.
  • Introduces TTT-Linear and TTT-MLP layers with linear complexity and expressive hidden states.
  • Focuses on inference capabilities, with a separate JAX codebase recommended for training.

Maintenance & Community

The project is associated with the "test-time-training" organization. Further community engagement details, such as Discord/Slack channels or roadmaps, are not explicitly mentioned in the README.

Licensing & Compatibility

The README does not explicitly state a license. Compatibility for commercial use or closed-source linking would require clarification on the licensing terms.

Limitations & Caveats

The README explicitly states this is a "naive implementation of TTT layers for tutorial purposes" and is not recommended for training due to a lack of systems optimization, leading to slow training times. For training and replicating paper results, the JAX codebase is recommended.

Health Check

  • Last commit: 1 year ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 1
  • Issues (30d): 0

Star History

59 stars in the last 90 days

Explore Similar Projects

Starred by Andrej Karpathy (Founder of Eureka Labs; formerly at Tesla, OpenAI; author of CS 231n) and Georgios Konstantopoulos (CTO, General Partner at Paradigm).

mlx-gpt2 by pranavjad

0.5% · 393 stars
Minimal GPT-2 implementation for educational purposes
created 1 year ago · updated 1 year ago
Starred by George Hotz (Author of tinygrad; Founder of the tiny corp, comma.ai), Daniel Gross (Cofounder of Safe Superintelligence), and 13 more.

RWKV-LM by BlinkDL

0.2% · 14k stars
RNN for LLM, transformer-level performance, parallelizable training
created 4 years ago · updated 1 week ago