PyTorch implementation of nGPT, a normalized GPT learning on the hypersphere
This repository provides a PyTorch implementation of nGPT (normalized GPT), a Transformer variant that learns on the hypersphere. It aims to improve Transformer performance, particularly in areas like continual learning and reinforcement learning, by normalizing attention queries and keys. The project is suitable for researchers and practitioners interested in exploring novel Transformer architectures.
How It Works
nGPT modifies the standard Transformer architecture by l2-normalizing the query and key vectors before the attention calculation, so that attention scores become cosine similarities. This hypersphere-inspired constraint aims to improve training stability; its effect on expressivity is an open question (see Limitations & Caveats). The implementation is a direct translation of the concepts presented in the associated research paper.
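As a minimal sketch of the idea (not the repository's code), the attention scores can be computed from l2-normalized queries and keys; since dot products of unit vectors lie in [-1, 1], a softmax temperature is needed, and the scale value and function name below are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

def normalized_attention(q, k, v, scale = 10.0):
    # project q and k onto the unit hypersphere along the feature dimension,
    # so q.k becomes a cosine similarity
    q = F.normalize(q, dim = -1)
    k = F.normalize(k, dim = -1)

    # cosine similarities are bounded in [-1, 1], so a scale (here an
    # assumed value) sharpens the softmax in place of 1/sqrt(d)
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
    attn = sim.softmax(dim = -1)
    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 16, 64)   # (batch, heads, seq, dim_head)
out = normalized_attention(q, k, v)     # (1, 8, 16, 64)
```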
Quick Start & Requirements
pip install nGPT-pytorch
python train.py
This trains a model on the Enwik8 dataset.

Highlighted Details
The attn_norm_qk parameter enables normalized attention, i.e. l2-normalization of the query and key vectors.
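A hedged usage sketch follows, assuming the package exposes an nGPT class whose constructor accepts this flag; all argument names and values other than attn_norm_qk are illustrative assumptions and should be verified against the repository:

```python
import torch
from nGPT_pytorch import nGPT  # import path assumed from the package name

# hyperparameters below are illustrative assumptions, not documented defaults
model = nGPT(
    num_tokens = 256,      # byte-level vocabulary (assumed)
    dim = 512,
    depth = 8,
    attn_norm_qk = True    # enable normalized attention, per the README
)

ids = torch.randint(0, 256, (1, 1024))  # dummy byte sequence
logits = model(ids)                     # assumed to return per-token logits
```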
Maintenance & Community

The project is associated with the author "lucidrains," known for various PyTorch implementations of recent deep learning models. No specific community channels or roadmap are detailed in the README.
Licensing & Compatibility
The README does not explicitly state a license. Given the author's typical practices and the nature of such implementations, it is likely MIT or a similar permissive license, but this should be verified. Compatibility with commercial or closed-source projects is assumed to be high if it is indeed MIT licensed.
Limitations & Caveats
The implementation is described as a "quick implementation," suggesting it may not include all optimizations or features of a production-ready library. The README raises a question about potential loss of expressivity due to the normalization, which warrants further investigation by users.