Framework for training foundation models with JAX
Levanter is a JAX-based framework for training large foundation models, prioritizing legibility, scalability, and reproducibility. It targets researchers and engineers who build and experiment with LLMs, offering a high-performance, deterministic training environment.
How It Works
Levanter builds on JAX for automatic vectorization (vmap) and just-in-time compilation (jit) through XLA. Model code is written with Haliax, a named-tensor library in which array axes are referred to by name rather than by position; this makes deep learning code more composable and readable and removes much of the error-prone axis bookkeeping of positional tensor code. Because sharding can likewise be specified per named axis, the same code scales to distributed training on GPUs and TPUs with techniques such as Fully Sharded Data Parallelism (FSDP) and tensor parallelism.
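To make the named-axis idea concrete, here is a toy sketch in plain Python. It is not Haliax's API and does not use JAX; `Named2D` and `named_dot` are hypothetical names introduced only for illustration. Each array carries axis names, and a contraction is requested by name, so the caller does not need to know whether `embed` is stored as the first or second axis of a weight matrix.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Named2D:
    """Hypothetical toy 2-D array whose axes are identified by name (not Haliax's API)."""
    axes: Tuple[str, str]       # e.g. ("batch", "embed")
    data: List[List[float]]     # row-major: rows index axes[0], columns index axes[1]

    def transpose_to(self, axes: Tuple[str, str]) -> "Named2D":
        """Return an array with the same contents laid out in the requested axis order."""
        if axes == self.axes:
            return self
        if axes == (self.axes[1], self.axes[0]):
            return Named2D(axes, [list(col) for col in zip(*self.data)])
        raise ValueError(f"cannot reorder {self.axes} to {axes}")


def named_dot(a: Named2D, b: Named2D, axis: str) -> Named2D:
    """Contract a and b over the axis called `axis`, wherever it sits positionally."""
    if axis not in a.axes or axis not in b.axes:
        raise ValueError(f"both arrays need an axis named {axis!r}")
    # The axis that survives from each operand is the one that is not contracted.
    a_keep = a.axes[0] if a.axes[1] == axis else a.axes[1]
    b_keep = b.axes[0] if b.axes[1] == axis else b.axes[1]
    # Put the contracted axis last in `a` and first in `b`, then do a plain matmul.
    a2 = a.transpose_to((a_keep, axis)).data
    b2 = b.transpose_to((axis, b_keep)).data
    out = [[sum(x * y for x, y in zip(row, col)) for col in zip(*b2)] for row in a2]
    return Named2D((a_keep, b_keep), out)


# batch=2, embed=3
x = Named2D(("batch", "embed"), [[1, 2, 3], [4, 5, 6]])
# The same weights stored in two different layouts: (hidden, embed) and (embed, hidden).
w1 = Named2D(("hidden", "embed"), [[1, 0, 0], [0, 1, 1]])
w2 = Named2D(("embed", "hidden"), [[1, 0], [0, 1], [0, 1]])

y = named_dot(x, w1, "embed")
print(y.axes, y.data)  # ('batch', 'hidden') [[1, 5], [4, 11]]
print(named_dot(x, w2, "embed").data == y.data)  # True
```

Keying the contraction on a name rather than a position is what lets both calls produce the same result without the caller transposing either weight layout by hand.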
Quick Start & Requirements
pip install levanter
or, after cloning the repository, pip install -e .
Then start a training run:
python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml
Highlighted Details
Maintenance & Community
Community discussion: #levanter on the unofficial Jax LLM Discord.
Licensing & Compatibility
Limitations & Caveats
GPU support is still in progress. Resuming training on a different number of hosts currently breaks reproducibility.
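The sketch below does not reproduce Levanter's data loader or state the actual cause of this limitation. It assumes a hypothetical loader with contiguous per-host shards (`host_shard` and `global_batch` are illustrative names) to show one way the contents of each global batch can depend on the host count, so that resuming on a different number of hosts repeats some examples and skips others even with a fixed global batch size.

```python
from collections import Counter
from typing import List


def host_shard(num_examples: int, num_hosts: int, host: int) -> List[int]:
    """Hypothetical contiguous sharding: host `host` owns one contiguous block of example ids."""
    per_host = num_examples // num_hosts
    return list(range(host * per_host, (host + 1) * per_host))


def global_batch(step: int, num_examples: int, num_hosts: int, batch_size: int) -> List[int]:
    """Example ids in the global batch at `step`; each host contributes
    batch_size // num_hosts consecutive examples from its own shard."""
    assert num_examples % num_hosts == 0 and batch_size % num_hosts == 0
    per_host = batch_size // num_hosts
    batch: List[int] = []
    for h in range(num_hosts):
        shard = host_shard(num_examples, num_hosts, h)
        batch.extend(shard[step * per_host:(step + 1) * per_host])
    return sorted(batch)  # order within a batch does not matter for the gradient


N, B = 16, 4  # 16 examples, global batch size 4 -> 4 steps per epoch

# Reference run: all 4 steps on 2 hosts.
uninterrupted = [global_batch(s, N, 2, B) for s in range(4)]
# Checkpoint after step 1 on 2 hosts, then resume at step 2 on 4 hosts.
resumed = [global_batch(s, N, 2, B) for s in range(2)] + [
    global_batch(s, N, 4, B) for s in range(2, 4)
]

print(uninterrupted)  # [[0, 1, 8, 9], [2, 3, 10, 11], [4, 5, 12, 13], [6, 7, 14, 15]]
print(resumed)        # [[0, 1, 8, 9], [2, 3, 10, 11], [2, 6, 10, 14], [3, 7, 11, 15]]

counts = Counter(i for b in resumed for i in b)
# Examples seen twice (2, 3, 10, 11) or never (4, 5, 12, 13) after the resume.
print(sorted(i for i in range(N) if counts[i] != 1))  # [2, 3, 4, 5, 10, 11, 12, 13]
```

A loader that defines the global example order independently of the host count (and restores its position from the checkpoint) would avoid this particular failure mode.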