vision_transformer by google-research

Vision Transformer and MLP-Mixer models in JAX/Flax

created 4 years ago
11,632 stars

Top 4.4% on sourcepulse

Project Summary

This repository provides JAX/Flax implementations for Vision Transformer (ViT) and MLP-Mixer architectures, along with pre-trained models and fine-tuning scripts. It targets researchers and practitioners in computer vision who want to leverage state-of-the-art transformer-based models for image recognition tasks. The primary benefit is access to a wide range of pre-trained models and a flexible framework for experimentation and fine-tuning.

How It Works

The core approach involves splitting images into fixed-size patches, linearly embedding them, adding positional embeddings, and processing the sequence through a standard Transformer encoder (for ViT) or specialized token-mixing and channel-mixing MLPs (for MLP-Mixer). This patch-based processing allows these models to handle images as sequences, enabling the application of powerful sequence modeling techniques to vision tasks. The use of JAX/Flax facilitates efficient computation on accelerators like GPUs and TPUs.
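
The patch-embedding step can be sketched in a few lines of Flax. This is a minimal illustration, not the repository's actual code; the module name, widths, and initializer below are assumptions:

```python
import flax.linen as nn

class PatchEmbed(nn.Module):
    """Hypothetical sketch of ViT's patch embedding (not the repo's module)."""
    patch_size: int = 16   # assumed; matches the "/16" model variants
    hidden_dim: int = 768  # assumed; ViT-B width, for illustration

    @nn.compact
    def __call__(self, images):  # images: [batch, H, W, 3]
        # A strided convolution is equivalent to cutting non-overlapping
        # patches and applying one shared linear projection to each.
        x = nn.Conv(self.hidden_dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size))(images)
        b, h, w, c = x.shape
        x = x.reshape(b, h * w, c)  # flatten the grid into a token sequence
        # Learned positional embedding, one vector per patch position.
        pos = self.param('pos_embed', nn.initializers.normal(0.02),
                         (1, h * w, c))
        return x + pos
```

From here, ViT prepends a learnable class token and runs the sequence through a standard Transformer encoder, while MLP-Mixer instead applies the token-mixing and channel-mixing MLPs described above.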

Quick Start & Requirements

  • Installation: pip install -r vit_jax/requirements.txt (for GPU) or pip install -r vit_jax/requirements-tpu.txt (for TPU); Flaxformer must also be installed.
  • Prerequisites: Python >= 3.10, JAX, Flaxformer.
  • Fine-tuning: Example commands provided for fine-tuning ViT and MLP-Mixer models on custom datasets.
  • Resources: Pre-trained models are available via Google Cloud Storage (GCS) buckets (see the loading sketch after this list). Detailed setup instructions for Google Cloud VMs with GPUs or TPUs are included.
  • Demos: Interactive Colab notebooks are available for exploring the models and for fine-tuning.
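
As a quick sanity check after installation, a pre-trained checkpoint can be read straight from the GCS bucket and inspected. This is a minimal sketch; the exact bucket path is an assumption drawn from the public gs://vit_models listing, so verify the filename against the README's model tables:

```python
import numpy as np
import tensorflow as tf  # tf.io.gfile can read directly from GCS

# Assumed checkpoint path; confirm against the README before use.
path = 'gs://vit_models/imagenet21k/ViT-B_16.npz'
with tf.io.gfile.GFile(path, 'rb') as f:
    params = np.load(f)
    # Checkpoints are flat .npz archives keyed by weight-tensor name.
    for name in list(params.keys())[:5]:
        print(name, params[name].shape)
```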

Highlighted Details

  • Offers over 50,000 ViT and hybrid checkpoints from the "How to train your ViT?" paper, with varying data augmentation and regularization.
  • Includes implementations and pre-trained models for MLP-Mixer, an all-MLP architecture for vision (see the sketch after this list).
  • Provides models and code for LiT (Locked-image text Tuning) for zero-shot transfer capabilities.
  • Extensive benchmarking results are provided for various model sizes, datasets, and training configurations.
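
As flagged in the MLP-Mixer bullet above, each Mixer block alternates a token-mixing MLP (applied across patches) with a channel-mixing MLP (applied across features). A minimal Flax sketch with illustrative names and dimensions, simplified from the paper rather than copied from the repository:

```python
import flax.linen as nn

def mlp(x, hidden_dim, out_dim):
    # Two-layer MLP with GELU, used for both mixing steps.
    x = nn.Dense(hidden_dim)(x)
    x = nn.gelu(x)
    return nn.Dense(out_dim)(x)

class MixerBlock(nn.Module):
    """Hypothetical sketch of one Mixer block (not the repo's module)."""
    tokens_hidden: int = 384     # assumed token-mixing width
    channels_hidden: int = 3072  # assumed channel-mixing width

    @nn.compact
    def __call__(self, x):  # x: [batch, num_patches, channels]
        # Token mixing: transpose so the MLP acts across the patch axis,
        # letting information flow between spatial locations.
        y = nn.LayerNorm()(x).transpose(0, 2, 1)
        y = mlp(y, self.tokens_hidden, y.shape[-1])
        x = x + y.transpose(0, 2, 1)
        # Channel mixing: the MLP acts per patch across feature channels.
        y = nn.LayerNorm()(x)
        return x + mlp(y, self.channels_hidden, y.shape[-1])
```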

Maintenance & Community

This repository is associated with multiple research papers from Google Research. While specific community channels like Discord/Slack are not mentioned, the project is linked to the google-research GitHub organization.

Licensing & Compatibility

The repository is released under an unspecified open-source license, and a disclaimer states it is "not an official Google product" and was "forked and modified from google-research/big_transfer." Suitability for commercial use or closed-source linking depends on the specific license, which should be confirmed in the repository itself.

Limitations & Caveats

The README notes that Google Colab's single GPU/TPU setup can lead to slow training speeds, recommending dedicated machines for non-trivial fine-tuning. Integrating a custom dataset requires modifying vit_jax/input_pipeline.py. For the LiT models, this repository contains only evaluation code; the training code lives in google-research/big_vision.
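
The custom-dataset caveat mostly amounts to producing (image, label) batches in the shape the trainer expects. A hedged tf.data sketch of that kind of pipeline; the feature names and preprocessing below are assumptions for illustration, not the repository's actual input_pipeline code:

```python
import tensorflow as tf

def make_dataset(file_pattern, image_size=224, batch_size=128):
    """Sketch of a custom input pipeline yielding (image, label) batches.

    Assumes TFRecords with 'image/encoded' (JPEG) and 'label' features;
    adapt the feature spec and preprocessing to your own data.
    """
    features = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def parse(record):
        ex = tf.io.parse_single_example(record, features)
        image = tf.io.decode_jpeg(ex['image/encoded'], channels=3)
        image = tf.image.resize(image, [image_size, image_size])
        image = image / 127.5 - 1.0  # scale pixel values to [-1, 1]
        return image, ex['label']

    ds = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
    ds = ds.map(parse, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.shuffle(10_000).batch(batch_size).prefetch(tf.data.AUTOTUNE)
```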

Health Check

  • Last commit: 4 months ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 0
  • Issues (30d): 0
  • Star History: 380 stars in the last 90 days

Explore Similar Projects

Starred by Jeremy Howard (Cofounder of fast.ai) and Stas Bekman (Author of Machine Learning Engineering Open Book; Research Engineer at Snowflake).

SwissArmyTransformer by THUDM

Top 0.3% on sourcepulse · 1k stars
Transformer library for flexible model development
created 3 years ago · updated 7 months ago