Vision Transformer and MLP-Mixer models in JAX/Flax
This repository provides JAX/Flax implementations for Vision Transformer (ViT) and MLP-Mixer architectures, along with pre-trained models and fine-tuning scripts. It targets researchers and practitioners in computer vision who want to leverage state-of-the-art transformer-based models for image recognition tasks. The primary benefit is access to a wide range of pre-trained models and a flexible framework for experimentation and fine-tuning.
How It Works
The core approach involves splitting images into fixed-size patches, linearly embedding them, adding positional embeddings, and processing the sequence through a standard Transformer encoder (for ViT) or specialized token-mixing and channel-mixing MLPs (for MLP-Mixer). This patch-based processing allows these models to handle images as sequences, enabling the application of powerful sequence modeling techniques to vision tasks. The use of JAX/Flax facilitates efficient computation on accelerators like GPUs and TPUs.
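To make the patch-embedding step concrete, here is a minimal Flax sketch. It is illustrative only, not the repository's actual implementation; the module and parameter names (`PatchEmbed`, `pos_embed`) are made up for this example:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class PatchEmbed(nn.Module):
    """Split an image into patches and linearly embed them (ViT-style sketch)."""
    patch_size: int = 16
    hidden_dim: int = 768

    @nn.compact
    def __call__(self, images):  # images: [B, H, W, C]
        # A strided convolution is equivalent to extracting non-overlapping
        # patches and applying a shared linear projection to each one.
        x = nn.Conv(self.hidden_dim,
                    kernel_size=(self.patch_size, self.patch_size),
                    strides=(self.patch_size, self.patch_size))(images)
        b, h, w, d = x.shape
        x = x.reshape(b, h * w, d)  # [B, num_patches, hidden_dim]
        # Learned positional embedding, one vector per patch position.
        pos = self.param('pos_embed', nn.initializers.normal(0.02),
                         (1, h * w, d))
        return x + pos  # token sequence, ready for the Transformer encoder

x = jnp.ones((1, 224, 224, 3))
params = PatchEmbed().init(jax.random.PRNGKey(0), x)
tokens = PatchEmbed().apply(params, x)  # shape (1, 196, 768)
```

From here, ViT feeds the token sequence through standard self-attention blocks, while MLP-Mixer replaces attention with alternating MLPs applied across the token and channel dimensions.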
Quick Start & Requirements
Install dependencies with one of:

pip install -r vit_jax/requirements.txt (for GPU)
pip install -r vit_jax/requirements-tpu.txt (for TPU)

Requires Python >= 3.10. Flaxformer must also be installed.
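After installing, a quick sanity check confirms that JAX can see the accelerator. These are generic JAX calls, not repo-specific APIs:

```python
import jax

# Should list GPU/TPU devices rather than falling back to CPU.
print(jax.devices())
print(jax.local_device_count())
```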
Maintenance & Community
This repository is associated with multiple research papers from Google Research. While no dedicated community channels (Discord/Slack) are mentioned, the project lives under the google-research GitHub organization.
Licensing & Compatibility
The repository is released under an unspecified open-source license. A disclaimer states that it is "not an official Google product" and was "forked and modified from google-research/big_transfer." Suitability for commercial use or closed-source linking cannot be assessed until the specific license is confirmed.
Limitations & Caveats
The README notes that Google Colab's single GPU/TPU setup can lead to slow training, and recommends dedicated machines for non-trivial fine-tuning. Integrating a custom dataset requires modifying vit_jax/input_pipeline.py (see the sketch below). The LiT models are supported only for evaluation in this repository; their training code lives in big_vision.
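As a rough illustration of the kind of change involved, a hypothetical tf.data pipeline for a custom dataset might look like the following. The feature keys, preprocessing, and expected output format in vit_jax/input_pipeline.py may well differ; treat this purely as a sketch:

```python
import tensorflow as tf

def make_dataset(image_paths, labels, image_size=224, batch_size=32):
    """Hypothetical tf.data pipeline for a custom dataset.

    Mirror the real transforms in vit_jax/input_pipeline.py; the key names
    ('image', 'label') and the [-1, 1] scaling here are assumptions.
    """
    def _load(path, label):
        image = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
        image = tf.image.resize(image, (image_size, image_size))
        image = image / 127.5 - 1.0  # scale pixel values to [-1, 1]
        return {'image': image, 'label': label}

    ds = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    return (ds.map(_load, num_parallel_calls=tf.data.AUTOTUNE)
              .shuffle(1024)
              .batch(batch_size)
              .prefetch(tf.data.AUTOTUNE))
```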