vision_transformer by google-research

Vision Transformer and MLP-Mixer models in JAX/Flax

Created 5 years ago
12,317 stars

Top 4.1% on SourcePulse

Project Summary

This repository provides JAX/Flax implementations for Vision Transformer (ViT) and MLP-Mixer architectures, along with pre-trained models and fine-tuning scripts. It targets researchers and practitioners in computer vision who want to leverage state-of-the-art transformer-based models for image recognition tasks. The primary benefit is access to a wide range of pre-trained models and a flexible framework for experimentation and fine-tuning.

How It Works

The core approach involves splitting images into fixed-size patches, linearly embedding them, adding positional embeddings, and processing the sequence through a standard Transformer encoder (for ViT) or specialized token-mixing and channel-mixing MLPs (for MLP-Mixer). This patch-based processing allows these models to handle images as sequences, enabling the application of powerful sequence modeling techniques to vision tasks. The use of JAX/Flax facilitates efficient computation on accelerators like GPUs and TPUs.
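
As a rough illustration of this patch-based front end (a minimal sketch in Flax, not the repository's actual code; class and parameter names are illustrative), the patch embedding plus learned positional embeddings can be written as a strided convolution followed by a reshape:

    import flax.linen as nn


    class PatchEmbed(nn.Module):
        # Split an image into patches and linearly embed them (illustrative sketch).
        patch_size: int = 16
        hidden_dim: int = 768

        @nn.compact
        def __call__(self, images):  # images: [batch, height, width, channels]
            # A strided convolution with stride == kernel size is equivalent to
            # cutting non-overlapping patches and applying a shared linear projection.
            x = nn.Conv(self.hidden_dim,
                        kernel_size=(self.patch_size, self.patch_size),
                        strides=(self.patch_size, self.patch_size))(images)
            batch, h, w, c = x.shape
            x = x.reshape(batch, h * w, c)  # sequence of patch embeddings
            # Learned positional embeddings, one per patch position.
            pos = self.param('pos_embedding',
                             nn.initializers.normal(stddev=0.02),
                             (1, h * w, c))
            return x + pos

The resulting sequence is then processed by standard Transformer encoder blocks (multi-head self-attention followed by an MLP, each with residual connections and layer normalization).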

Quick Start & Requirements

  • Installation: pip install -r vit_jax/requirements.txt (for GPU) or pip install -r vit_jax/requirements-tpu.txt (for TPU). Requires Python >= 3.10. Flaxformer installation is also needed.
  • Prerequisites: JAX, Flaxformer, Python 3.10+.
  • Fine-tuning: Example commands are provided for fine-tuning ViT and MLP-Mixer models on custom datasets; a generic training-step sketch follows this list.
  • Resources: Pre-trained models are available via Google Cloud Storage (GCS) buckets. Detailed setup instructions for Google Cloud VMs with GPUs or TPUs are included.
  • Demos: Interactive Colab notebooks are available for exploring the models and for fine-tuning.
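
The repository's README documents the exact fine-tuning commands and config files; the following is only a generic sketch of what a single fine-tuning step looks like in JAX/Flax with optax (the model's apply function, its call signature, and the batch layout are assumptions for illustration):

    import jax
    import optax


    def make_train_step(apply_fn, tx):
        # apply_fn: a Flax model's apply function (e.g. model.apply) returning logits.
        # tx: an optax optimizer, e.g. optax.sgd(learning_rate=1e-3, momentum=0.9).

        @jax.jit
        def train_step(params, opt_state, images, labels):
            def loss_fn(p):
                logits = apply_fn({'params': p}, images)
                one_hot = jax.nn.one_hot(labels, logits.shape[-1])
                return optax.softmax_cross_entropy(logits, one_hot).mean()

            loss, grads = jax.value_and_grad(loss_fn)(params)
            updates, opt_state = tx.update(grads, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state, loss

        return train_step

In the repository itself, fine-tuning is driven by the provided example commands and configuration files rather than a hand-written loop like this.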

Highlighted Details

  • Offers over 50,000 ViT and hybrid checkpoints from the "How to train your ViT?" paper, with varying data augmentation and regularization.
  • Includes implementations and pre-trained models for MLP-Mixer, an all-MLP architecture for vision (a block sketch follows this list).
  • Provides models and code for LiT (Locked-image text Tuning) for zero-shot transfer capabilities.
  • Extensive benchmarking results are provided for various model sizes, datasets, and training configurations.
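
To make the MLP-Mixer item above concrete, here is a minimal sketch of the token-mixing/channel-mixing block structure from the Mixer paper, written in Flax; it is an illustrative reimplementation, not the repository's code:

    import jax.numpy as jnp
    import flax.linen as nn


    class MlpBlock(nn.Module):
        # Two-layer MLP with GELU, used for both mixing steps.
        hidden_dim: int

        @nn.compact
        def __call__(self, x):
            out_dim = x.shape[-1]
            x = nn.Dense(self.hidden_dim)(x)
            x = nn.gelu(x)
            return nn.Dense(out_dim)(x)


    class MixerBlock(nn.Module):
        # One Mixer block: token mixing across patches, then channel mixing.
        tokens_mlp_dim: int
        channels_mlp_dim: int

        @nn.compact
        def __call__(self, x):  # x: [batch, num_patches, channels]
            y = nn.LayerNorm()(x)
            y = jnp.swapaxes(y, 1, 2)              # mix across the patch dimension
            y = MlpBlock(self.tokens_mlp_dim)(y)
            y = jnp.swapaxes(y, 1, 2)
            x = x + y                              # residual connection
            y = nn.LayerNorm()(x)
            return x + MlpBlock(self.channels_mlp_dim)(y)  # mix across channels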

Maintenance & Community

This repository is associated with multiple research papers from Google Research. While no dedicated community channels such as Discord or Slack are mentioned, the project is hosted under the google-research GitHub organization.

Licensing & Compatibility

The repository is released under an unspecified open-source license. A disclaimer states it is "not an official Google product" and that the code was "forked and modified from google-research/big_transfer." Suitability for commercial use or closed-source linking should be confirmed against the repository's LICENSE file.

Limitations & Caveats

The README notes that Google Colab's single GPU/TPU setup can lead to slow training, and recommends dedicated machines for non-trivial fine-tuning. Integrating custom datasets requires modifying vit_jax/input_pipeline.py. Only evaluation code for the LiT models is included in this repository; the training code lives in big_vision.
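
Regarding the custom-dataset caveat above, the actual hooks live in vit_jax/input_pipeline.py; purely as an illustration of the kind of pipeline that has to be produced, a minimal tf.data loader yielding image/label batches might look like this (dataset name, image size, and preprocessing are placeholder assumptions):

    import tensorflow as tf
    import tensorflow_datasets as tfds


    def make_dataset(name='cifar10', split='train', image_size=224, batch_size=64):
        # Illustrative sketch only; the repository's pipeline differs in detail.
        ds = tfds.load(name, split=split, as_supervised=True)

        def preprocess(image, label):
            image = tf.image.resize(image, (image_size, image_size))
            image = tf.cast(image, tf.float32) / 255.0  # scale pixels to [0, 1]
            return image, label

        return (ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
                  .shuffle(10_000)
                  .batch(batch_size, drop_remainder=True)
                  .prefetch(tf.data.AUTOTUNE))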

Health Check

  • Last Commit: 3 weeks ago
  • Responsiveness: 1 week
  • Pull Requests (30d): 2
  • Issues (30d): 1
  • Star History: 81 stars in the last 30 days

Explore Similar Projects

Starred by Andrej Karpathy (Founder of Eureka Labs; formerly at Tesla, OpenAI; author of CS 231n), Jiayi Pan (author of SWE-Gym; MTS at xAI), and 15 more.

taming-transformers by CompVis

  • Top 0.1% on SourcePulse
  • 6k stars
  • Image synthesis research paper using transformers
  • Created 5 years ago, updated 1 year ago