tpu-starter by ayaka14732

TPU guide for JAX-based ML workflows on Google Cloud

Created 3 years ago · 536 stars · Top 60.0% on sourcepulse

Project Summary

This repository is a comprehensive guide to running machine learning workloads on Google Cloud TPUs. It targets researchers and engineers who want to understand TPU capabilities, set up development environments, and optimize their workflows, and it covers a path to free TPU resources (the TRC program) as well as detailed instructions for both single TPU VMs and multi-host TPU Pods.

How It Works

The project centers on the JAX framework, highlighting its strong compatibility and performance on TPUs in contrast to PyTorch's limited support. It walks through provisioning TPU VMs and Pods on Google Cloud Platform, configuring SSH access, setting up a development environment with Python 3.12 and JAX, using Byobu to keep long-running jobs alive across SSH disconnects, and using VSCode for remote development. The guide also covers JAX and Optax best practices for efficient model training.
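
As a flavor of the pattern the guide builds toward, here is a minimal JAX + Optax training step. This is an illustrative sketch with a toy linear model, not code from the repository:

    import jax
    import jax.numpy as jnp
    import optax

    def loss_fn(params, x, y):
        pred = x @ params["w"] + params["b"]  # toy linear model
        return jnp.mean((pred - y) ** 2)

    params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
    optimizer = optax.adamw(learning_rate=1e-3)
    opt_state = optimizer.init(params)

    @jax.jit  # traced and compiled once, then runs on the TPU
    def train_step(params, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss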

Quick Start & Requirements

  • Install/Run: Provision TPUs with gcloud commands and connect with ssh. Install JAX via pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (a sanity check follows this list).
  • Prerequisites: Google Cloud account, gcloud CLI, SSH client. Specific TPU VM/Pod types (e.g., v3-8, v3-32) are detailed. Python 3.12 is recommended.
  • Setup: Provisioning cloud resources can take several minutes; the guide walks through configuring the development environment inside the TPU VM/Pod.
  • Links: TRC program homepage, Shawn's TRC article.
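
A quick way to confirm the installation worked, assuming you are on an already-provisioned TPU VM:

    import jax

    print(jax.devices())       # e.g. [TpuDevice(id=0, ...), ...]
    print(jax.device_count())  # 8 on a single v3-8 host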

Highlighted Details

  • Detailed instructions for setting up multi-host TPU Pods, including NFS configuration and distributed command execution with podrun.
  • Best practices for JAX, including random-number key handling, array conversions, and the Optax optimizer library (see the sketch after this list).
  • Troubleshooting common issues like TCMalloc conflicts and libtpu.so usage errors.
  • Guidance on preferring TPU VMs over older TPU nodes and GCP over Google Colab for greater flexibility.
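
The random-number practice in question is JAX's explicit key-splitting discipline; a minimal sketch, not taken from the repo:

    import jax

    key = jax.random.PRNGKey(42)
    # Never reuse a key: split it whenever fresh randomness is needed.
    key, init_key, dropout_key = jax.random.split(key, 3)
    noise = jax.random.normal(init_key, (8, 128))  # init_key is consumed here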

Maintenance & Community

  • The project is community-maintained, inspired by the Cloud Run FAQ.
  • Google's official Discord server has a #tpu-research-cloud channel for community support.

Licensing & Compatibility

  • The repository itself does not specify a license. It guides users through Google Cloud services and open-source libraries such as JAX and Optax, which carry their own permissive licenses (e.g., Apache 2.0 for JAX).

Limitations & Caveats

  • TPU VMs can be rebooted during maintenance, requiring mechanisms for resuming training.
  • A single TPU core can only be used by one process at a time, necessitating careful process management.
  • JAX does not support the fork multiprocessing start method; use spawn or forkserver instead (see the sketch after this list).
  • The guide explicitly describes PyTorch support on TPUs as poor.
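
For the fork limitation, a minimal sketch of the safe pattern (data_worker is a hypothetical helper; it does CPU-side work only, since a TPU core belongs to a single process):

    import multiprocessing as mp

    def data_worker(i):
        # CPU-side work (e.g. data loading); the parent process keeps the TPU.
        return i * i

    if __name__ == "__main__":
        mp.set_start_method("spawn")  # "forkserver" also works; "fork" breaks JAX
        with mp.Pool(4) as pool:
            print(pool.map(data_worker, range(4)))
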
Health Check

  • Last commit: 1 year ago
  • Responsiveness: 1 week
  • Pull requests (30d): 0
  • Issues (30d): 0
  • Star history: 10 stars in the last 90 days
