grokfast by ironjr

Research code for accelerating grokking via gradient amplification

created 1 year ago
559 stars

Top 58.3% on sourcepulse

View on GitHub
Project Summary

Grokfast accelerates the "grokking" phenomenon in machine learning, where models exhibit delayed generalization after overfitting. This project offers a simple, drop-in solution for practitioners seeking to speed up this process across diverse tasks like image, language, and graph modeling.

How It Works

Grokfast operates by spectrally decomposing parameter gradients, viewed as time series over training steps, into fast-varying and slow-varying components. It then amplifies the slow-varying components, which are hypothesized to drive generalization. This is achieved by integrating custom gradient filtering functions (EMA or MA) directly into the optimization loop, modifying gradients before the optimizer step. The approach aims to hasten the transition from overfitting to generalization without altering the model architecture or the rest of the training process.
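The following is a minimal, self-contained sketch of an EMA-style filter of this kind, assuming a PyTorch model. It is illustrative rather than the repository's grokfast.py: the function name, buffer layout, and default constants here are invented for the example.

```python
import torch

@torch.no_grad()
def ema_gradient_filter(model, ema_grads, alpha=0.98, lamb=2.0):
    """Amplify the slow-varying gradient component via an EMA low-pass filter.

    ema_grads: dict mapping parameter name -> EMA buffer (pass None on the first call).
    alpha: EMA decay; values closer to 1.0 correspond to a lower cutoff frequency.
    lamb: amplification factor for the slow (low-pass) component.
    """
    if ema_grads is None:
        ema_grads = {n: p.grad.detach().clone()
                     for n, p in model.named_parameters() if p.grad is not None}
    for n, p in model.named_parameters():
        if p.grad is None:
            continue
        # Low-pass filter the per-parameter gradient time series ...
        ema_grads[n].mul_(alpha).add_(p.grad, alpha=1.0 - alpha)
        # ... and add the amplified slow component back onto the raw gradient.
        p.grad.add_(ema_grads[n], alpha=lamb)
    return ema_grads
```

In this framing, alpha plays the role of the cutoff frequency the project's tuning guidance refers to, and lamb controls how strongly the slow component is amplified.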

Quick Start & Requirements

  • Install via pip install -r requirements.txt after cloning the repository.
  • Requires PyTorch.
  • Reproduction of experiments requires additional packages listed in requirements.txt.
  • Setup for basic usage involves downloading grokfast.py and importing its functions (see the usage sketch below).
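Assuming grokfast.py sits next to the training script, integration might look like the sketch below. The keyword arguments alpha and lamb are assumptions standing in for the filter's actual hyperparameters (the cutoff-frequency tuning mentioned above), not a confirmed signature.

```python
import torch
from grokfast import gradfilter_ema  # grokfast.py downloaded next to this script

# Toy model and data stand in for a real task.
model = torch.nn.Linear(128, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()
loader = [(torch.randn(32, 128), torch.randint(0, 10, (32,)))]  # placeholder batch

grads = None  # filter state carried across steps
for x, y in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # Filter/amplify gradients in place before the optimizer step.
    # NOTE: the keyword names `alpha` and `lamb` are assumptions, not confirmed API.
    grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)
    optimizer.step()
```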

Highlighted Details

  • Achieves over 50x acceleration in reaching generalization milestones in experiments.
  • Demonstrates effectiveness across Transformer decoders, MLPs, LSTMs, and G-CNNs on various datasets.
  • Offers two filtering methods: gradfilter_ema (Exponential Moving Average) and gradfilter_ma (Moving Average).
  • Provides guidance on hyperparameter tuning for cutoff frequencies and weight decay.

Maintenance & Community

  • Project initiated by researchers from Seoul National University.
  • Code is based on several prior grokking research projects.
  • Contact email provided for questions.

Licensing & Compatibility

  • Licensed under the MIT License.
  • Permissive license suitable for commercial use and integration into closed-source projects.

Limitations & Caveats

The provided hyperparameter recommendations are based on experimental experience and may require further tuning for optimal performance on new tasks. The gradfilter_ma function's additional memory requirements increase linearly with window_size, since the filter must buffer that many past gradients per parameter (see the sketch below).
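To make the memory caveat concrete, here is a minimal sketch of a moving-average filter of this kind; it is illustrative, not the repository's gradfilter_ma, and the names and defaults are invented. It retains the last window_size gradient tensors for every parameter, which is where the linear memory growth comes from.

```python
from collections import deque
import torch

@torch.no_grad()
def ma_gradient_filter(model, grad_window, window_size=100, lamb=5.0):
    """Moving-average variant: keeps the last `window_size` gradients per
    parameter, so extra memory grows linearly with window_size."""
    if grad_window is None:
        grad_window = {n: deque(maxlen=window_size)
                       for n, p in model.named_parameters() if p.requires_grad}
    for n, p in model.named_parameters():
        if p.grad is None:
            continue
        grad_window[n].append(p.grad.detach().clone())  # one stored copy per step
        if len(grad_window[n]) == window_size:  # warm up until a full window exists
            slow = torch.stack(list(grad_window[n])).mean(dim=0)
            p.grad.add_(slow, alpha=lamb)
    return grad_window
```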

Health Check

  • Last commit: 1 year ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 0
  • Issues (30d): 0

Star History

  • 11 stars in the last 90 days

Explore Similar Projects

Starred by Stas Bekman (Author of Machine Learning Engineering Open Book; Research Engineer at Snowflake).

applied-ai by pytorch-labs

Applied AI experiments and examples for PyTorch

created 2 years ago
updated 2 months ago
289 stars

Top 0.3% on sourcepulse