GradCache by luyug

Technique for contrastive learning beyond GPU memory limits

Created 4 years ago
408 stars

Top 71.5% on SourcePulse

Project Summary

This repository provides Gradient Cache, a technique for scaling contrastive learning batch sizes beyond GPU memory limits, letting training setups that previously required multiple GPUs run on a single one. It targets researchers and engineers working with large-scale contrastive learning models, offering significant cost and hardware efficiency benefits.

How It Works

Gradient Cache splits a large contrastive batch into smaller chunks. It first runs a gradient-free forward pass over each chunk and caches the resulting representations, then computes the full-batch contrastive loss on those cached representations and stores the gradient with respect to each representation (the gradient cache). Finally, it re-runs the forward pass chunk by chunk with gradients enabled and back-propagates the cached representation gradients, so encoder parameter gradients accumulate exactly as they would for a single full-batch pass. Peak memory is therefore bounded by the chunk size, which allows training with much larger effective batch sizes on limited hardware.
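
The sketch below illustrates this three-step procedure in plain PyTorch. It is not the library's API; the encoder, loss function, and chunk-size names are illustrative assumptions, and the optimizer is assumed to hold the parameters of both encoders.

    import torch

    def grad_cached_step(encoder_q, encoder_k, queries, keys, loss_fn, chunk_size, optimizer):
        # 1) Gradient-free chunked forward passes: compute and cache representations only.
        with torch.no_grad():
            q_reps = torch.cat([encoder_q(c) for c in queries.split(chunk_size)])
            k_reps = torch.cat([encoder_k(c) for c in keys.split(chunk_size)])

        # 2) Full-batch loss on the cached representations; backpropagation stops at the
        #    representations, yielding one gradient per representation (the gradient cache).
        q_reps = q_reps.detach().requires_grad_()
        k_reps = k_reps.detach().requires_grad_()
        loss = loss_fn(q_reps, k_reps)
        loss.backward()

        # 3) Second chunked forward pass with gradients enabled; injecting the cached
        #    representation gradients accumulates parameter gradients as if the full
        #    batch had been processed in one pass.
        for c, g in zip(queries.split(chunk_size), q_reps.grad.split(chunk_size)):
            encoder_q(c).backward(gradient=g)
        for c, g in zip(keys.split(chunk_size), k_reps.grad.split(chunk_size)):
            encoder_k(c).backward(gradient=g)

        optimizer.step()
        optimizer.zero_grad()
        return loss.item()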

Quick Start & Requirements

  • Install from source: git clone https://github.com/luyug/GradCache && cd GradCache && pip install .
  • Requires PyTorch or JAX.
  • Supports automatic mixed precision (AMP) with torch.cuda.amp.GradScaler; a sketch of how AMP combines with chunked gradient caching follows this list.
  • Official documentation and examples are available within the repository.
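
Because AMP's GradScaler scales the loss, the cached representation gradients must carry that scale through the second pass. The sketch below shows one way to combine the two in plain PyTorch; it is not the library's API, and the encoder, contrastive_loss, and chunk size are illustrative assumptions.

    import torch
    from torch.cuda.amp import autocast, GradScaler

    scaler = GradScaler()

    def amp_grad_cached_step(encoder, queries, keys, contrastive_loss, chunk_size, optimizer):
        # 1) Gradient-free chunked forwards under autocast; cache fp32 copies of the reps.
        with torch.no_grad(), autocast():
            q = torch.cat([encoder(c) for c in queries.split(chunk_size)]).float()
            k = torch.cat([encoder(c) for c in keys.split(chunk_size)]).float()

        # 2) Full-batch loss on the cached reps, scaled so that the cached
        #    representation gradients carry the AMP loss scale.
        q = q.detach().requires_grad_()
        k = k.detach().requires_grad_()
        scaler.scale(contrastive_loss(q, k)).backward()

        # 3) Chunked forwards with gradients enabled; inject the (scaled) cached gradients.
        for inputs, grads in ((queries, q.grad), (keys, k.grad)):
            for c, g in zip(inputs.split(chunk_size), grads.split(chunk_size)):
                with autocast():
                    rep = encoder(c)
                rep.float().backward(gradient=g)

        scaler.step(optimizer)   # unscales parameter gradients; skips the step on inf/nan
        scaler.update()
        optimizer.zero_grad()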

Highlighted Details

  • Enables training large batch contrastive learning models on a single GPU.
  • Supports both PyTorch and JAX frameworks.
  • Integrates with Huggingface Transformers models.
  • Offers a functional API with decorators for easier integration into new projects.
  • Includes support for distributed training scenarios.

Maintenance & Community

The project is associated with authors Luyu Gao, Yunyi Zhang, Jiawei Han, and Jamie Callan. Further community engagement channels are not explicitly mentioned in the README.

Licensing & Compatibility

The repository does not explicitly state a license in the provided README text. Users should verify licensing for commercial or closed-source use.

Limitations & Caveats

Input types not handled out of the box may require a custom split_input_fn (see the hypothetical example below). The chunk_sizes setting trades peak GPU memory against the number of forward passes per step, so it must be tuned to the available hardware.
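
As a concrete illustration, a custom splitter for a dictionary of batched tensors (e.g., Huggingface tokenizer output) might look like the hypothetical function below. The exact hook signature expected by the library should be checked against the repository's README; the assumption here is that it receives the model input and a chunk size and returns a list of chunks.

    def split_dict_of_tensors(model_input: dict, chunk_size: int):
        # Hypothetical split_input_fn: split every tensor field along the batch
        # dimension, then regroup the pieces into one dict per chunk.
        keys = list(model_input.keys())
        chunked = [model_input[k].split(chunk_size, dim=0) for k in keys]
        return [dict(zip(keys, parts)) for parts in zip(*chunked)]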

Health Check

  • Last Commit: 1 year ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 6 stars in the last 30 days

Starred by Andrej Karpathy (Founder of Eureka Labs; Formerly at Tesla, OpenAI; Author of CS 231n), Pawel Garbacki (Cofounder of Fireworks AI), and 11 more.

Explore Similar Projects

Liger-Kernel by linkedin

0.6%
6k
Triton kernels for efficient LLM training
Created 1 year ago
Updated 1 day ago
Starred by François Chollet (Author of Keras; Cofounder of Ndea, ARC Prize), Chaoyu Yang (Founder of Bento), and 13 more.

neon by NervanaSystems

0%
4k
Deep learning framework (discontinued)
Created 11 years ago
Updated 4 years ago
Starred by Andrej Karpathy (Founder of Eureka Labs; Formerly at Tesla, OpenAI; Author of CS 231n), Jiayi Pan (Author of SWE-Gym; MTS at xAI), and 34 more.

flash-attention by Dao-AILab

0.6%
20k
Fast, memory-efficient attention implementation
Created 3 years ago
Updated 1 day ago