GradCache by luyug

Technique for contrastive learning beyond GPU memory limits

created 4 years ago
399 stars

Top 73.5% on sourcepulse

1 Expert Loves This Project
Project Summary

This repository provides Gradient Cache, a technique for scaling contrastive learning batch sizes beyond GPU memory limits, enabling training runs on a single GPU that previously required multiple GPUs. It targets researchers and engineers working with large-scale contrastive learning models, offering significant cost and hardware efficiency benefits.

How It Works

Gradient Cache splits a large contrastive batch into small chunks and computes the exact full-batch gradient in two passes. It first encodes each chunk without building an autograd graph and caches the resulting representations, then computes the contrastive loss over all cached representations at once and stores the gradients of the loss with respect to those representations (the gradient cache). Finally, it re-encodes each chunk with gradients enabled and backpropagates the cached representation gradients through the encoder. Peak memory is bounded by the chunk size rather than the full batch, so training with a much larger effective batch size fits on limited hardware while producing the same parameter update as true large-batch training.
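
A minimal sketch of the two-pass procedure in plain PyTorch. This is not the library's API; the encoder, the in-batch loss, and the chunk size are illustrative placeholders.

    # Gradient-cache training step, sketched in plain PyTorch for illustration.
    import torch
    import torch.nn.functional as F

    def grad_cache_step(encoder, queries, keys, chunk_size, optimizer):
        # Pass 1: encode every chunk without autograd graphs and cache the reps.
        with torch.no_grad():
            q_reps = torch.cat([encoder(c) for c in queries.split(chunk_size)])
            k_reps = torch.cat([encoder(c) for c in keys.split(chunk_size)])

        # Full-batch contrastive loss on the cached reps; store d(loss)/d(reps).
        q_reps.requires_grad_(True)
        k_reps.requires_grad_(True)
        logits = q_reps @ k_reps.T
        labels = torch.arange(q_reps.size(0), device=logits.device)
        loss = F.cross_entropy(logits, labels)
        q_grads, k_grads = torch.autograd.grad(loss, [q_reps, k_reps])

        # Pass 2: re-encode chunk by chunk with grad enabled and backpropagate
        # the cached representation gradients; parameter grads sum over chunks.
        for rep_grads, inputs in ((q_grads, queries), (k_grads, keys)):
            for g, c in zip(rep_grads.split(chunk_size), inputs.split(chunk_size)):
                encoder(c).backward(gradient=g)

        optimizer.step()
        optimizer.zero_grad()
        return loss.item()

Only one chunk's activations are held in memory at a time, yet the accumulated parameter gradients match what a single forward-backward pass over the whole batch would produce.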

Quick Start & Requirements

  • Install from source with pip: git clone https://github.com/luyug/GradCache && cd GradCache && pip install .
  • Requires PyTorch or JAX.
  • Supports automatic mixed precision (AMP) with torch.cuda.amp.GradScaler.
  • Official documentation and examples are available within the repository; a minimal usage sketch follows this list.
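
The snippet below is a usage sketch only: the GradCache constructor arguments (models, chunk_sizes, loss_fn) and the call signature are recalled from the repository README and should be verified against its examples. The encoders and data are placeholders.

    # Usage sketch with placeholder encoders; verify the wrapper's exact
    # signature against the repository README before relying on it.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from grad_cache import GradCache

    def contrastive_loss(q_reps, p_reps):
        # In-batch negatives: the i-th query matches the i-th passage.
        labels = torch.arange(q_reps.size(0), device=q_reps.device)
        return F.cross_entropy(q_reps @ p_reps.T, labels)

    query_encoder = nn.Linear(128, 64)        # placeholder two-tower encoders
    passage_encoder = nn.Linear(128, 64)
    optimizer = torch.optim.AdamW(
        list(query_encoder.parameters()) + list(passage_encoder.parameters()),
        lr=1e-4,
    )

    gc = GradCache(
        models=[query_encoder, passage_encoder],
        chunk_sizes=8,                        # sub-batch size that fits in memory
        loss_fn=contrastive_loss,
    )

    # One step over an effective batch of 256 pairs: the wrapper splits the
    # inputs into chunks of 8 and accumulates the full-batch gradients.
    queries, passages = torch.randn(256, 128), torch.randn(256, 128)
    loss = gc(queries, passages)
    optimizer.step()
    optimizer.zero_grad()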

Highlighted Details

  • Enables training large batch contrastive learning models on a single GPU.
  • Supports both PyTorch and JAX frameworks.
  • Integrates with Huggingface Transformers models.
  • Offers a functional API with decorators for easier integration into new projects (see the sketch after this list).
  • Includes support for distributed training scenarios.
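
A sketch of that decorator-based interface follows. The names cached and cat_input_tensor in grad_cache.functional, and the convention that a cached call returns a (representation, closure) pair, are recalled from the README and should be treated as assumptions to verify there.

    # Hedged sketch of the functional interface; decorator names and the
    # (representation, closure) return convention are assumptions.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from grad_cache.functional import cached, cat_input_tensor

    @cached
    def encode(model, x):
        # First pass runs without grad; the returned closure later replays the
        # forward with grad and backpropagates the cached representation grad.
        return model(x)

    @cat_input_tensor
    def contrastive_loss(q_reps, p_reps):
        # Called on the concatenated full-batch representations.
        labels = torch.arange(q_reps.size(0), device=q_reps.device)
        return F.cross_entropy(q_reps @ p_reps.T, labels)

    encoder = nn.Linear(128, 64)                   # placeholder encoder
    optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-4)

    reps_q, reps_p, closures = [], [], []
    for _ in range(4):                             # 4 sub-batches of 8 = batch of 32
        q_chunk, p_chunk = torch.randn(8, 128), torch.randn(8, 128)
        rq, cq = encode(encoder, q_chunk)
        rp, cp = encode(encoder, p_chunk)
        reps_q.append(rq)
        reps_p.append(rp)
        closures += [(cq, rq), (cp, rp)]

    loss = contrastive_loss(reps_q, reps_p)        # loss over the full batch
    loss.backward()                                # grads land on the cached reps
    for closure, rep in closures:
        closure(rep)                               # second pass per sub-batch
    optimizer.step()
    optimizer.zero_grad()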

Maintenance & Community

The project is associated with authors Luyu Gao, Yunyi Zhang, Jiawei Han, and Jamie Callan. Further community engagement channels are not explicitly mentioned in the README.

Licensing & Compatibility

The repository does not explicitly state a license in the provided README text. Users should verify licensing for commercial or closed-source use.

Limitations & Caveats

The README notes that generic input types not handled out of the box may require a custom split_input_fn, and that the choice of chunk_sizes depends on GPU memory utilization and needs tuning; a sketch of such a splitting function follows.
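
For a dict-of-tensors input, such a splitting function might look like the sketch below; the exact callable contract expected by GradCache's split_input_fn argument should be checked in the class docstring, so treat the signature here as an assumption.

    # Hypothetical split function for inputs given as a dict of equal-length
    # tensors; the (model_input, chunk_size) -> list-of-chunks contract is an
    # assumption to verify against the GradCache documentation.
    def split_dict_input(model_input, chunk_size):
        keys = list(model_input.keys())
        chunked = [model_input[k].split(chunk_size, dim=0) for k in keys]
        return [dict(zip(keys, vals)) for vals in zip(*chunked)]

    # Example: {"input_ids": (32, 128), "attention_mask": (32, 128)} with
    # chunk_size=8 yields four dicts whose tensors each have shape (8, 128).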

Health Check
  • Last commit: 1 year ago
  • Responsiveness: 1 day
  • Pull Requests (30d): 0
  • Issues (30d): 0

Star History
  • 15 stars in the last 90 days

Explore Similar Projects

Starred by Stas Bekman (Author of Machine Learning Engineering Open Book; Research Engineer at Snowflake).

fms-fsdp by foundation-model-stack
0.4% · 258 stars
Efficiently train foundation models with PyTorch
created 1 year ago · updated 1 week ago
Starred by Jeff Hammerbacher (Cofounder of Cloudera) and Stas Bekman (Author of Machine Learning Engineering Open Book; Research Engineer at Snowflake).

InternEvo by InternLM
1.0% · 402 stars
Lightweight training framework for model pre-training
created 1 year ago · updated 1 week ago
Starred by Jeremy Howard (Cofounder of fast.ai) and Stas Bekman (Author of Machine Learning Engineering Open Book; Research Engineer at Snowflake).

SwissArmyTransformer by THUDM
0.3% · 1k stars
Transformer library for flexible model development
created 3 years ago · updated 7 months ago