nanodl by HenryNdubuaku

JAX library for building transformer models, including GPT, Gemma, LLaMA, Mixtral, Whisper, Swin, and ViT

created 1 year ago
290 stars

Top 91.7% on sourcepulse

View on GitHub
1 Expert Loves This Project
Project Summary

NanoDL is a JAX-based library for building and training transformer models from scratch, aimed at AI/ML experts who need to develop smaller-scale, efficient models. Its modular, pedagogically oriented code makes it easy to customize architectures and accelerate development, with built-in support for distributed training.

How It Works

NanoDL builds on JAX and Flax for efficient computation and distributed training. Its core design emphasizes modularity: each model and its components live in a single file to minimize dependencies and make the code easy to study. This lets users select, combine, and modify layers and blocks, including specialized ones such as rotary position embeddings (RoPE), grouped-query attention (GQA), and multi-query attention (MQA), for flexible model development.
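To make the modular-layer idea concrete, here is a minimal JAX sketch of rotary position embeddings (RoPE) in the half-split formulation. It is an illustrative standalone implementation under the stated shape assumptions, not nanodl's own code; the function name apply_rope is hypothetical.

```python
import jax
import jax.numpy as jnp

def apply_rope(x, base=10000.0):
    """Rotary position embeddings (half-split variant) for a tensor of shape
    (batch, seq_len, num_heads, head_dim). Generic illustration only."""
    _, seq_len, _, head_dim = x.shape
    half = head_dim // 2
    inv_freq = 1.0 / (base ** (jnp.arange(half) / half))        # (half,)
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]   # (seq_len, half)
    sin = jnp.sin(angles)[None, :, None, :]                     # broadcast over batch/heads
    cos = jnp.cos(angles)[None, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return jnp.concatenate([x1 * cos - x2 * sin,
                            x1 * sin + x2 * cos], axis=-1)

# Rotate the queries of a toy attention layer.
q = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 4, 64))
print(apply_rope(q).shape)  # (2, 16, 4, 64)
```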

Quick Start & Requirements

  • Install via pip: pip install nanodl
  • Prerequisites: Python 3.9+, JAX, Flax, and Optax (GPU support recommended for training); a minimal environment check is sketched after this list.
  • Training runs on one or more GPUs/TPUs; a CPU-only JAX install is sufficient for creating models.
  • Documentation and examples are linked from badges in the README, with support available via Discord.
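Before training anything, it can be worth confirming that the JAX/Flax/Optax stack works on the target hardware. The sketch below is a generic environment check using only standard JAX, Flax, and Optax calls; TinyMLP is a hypothetical stand-in module, not a nanodl class.

```python
# Minimal environment check before installing/using nanodl.
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

print("Visible JAX devices:", jax.devices())  # CPU, GPU, and/or TPU

class TinyMLP(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.features)(x))
        return nn.Dense(2)(x)

model = TinyMLP()
x = jnp.ones((4, 16))
params = model.init(jax.random.PRNGKey(0), x)

def loss_fn(p):
    return jnp.mean(model.apply(p, x) ** 2)

grads = jax.grad(loss_fn)(params)            # one backward pass
tx = optax.adam(1e-3)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)
print("Forward/backward pass and optimizer update succeeded.")
```

If jax.devices() lists only CPU devices, model creation will still work, but training will be slow.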

Highlighted Details

  • Implements a wide array of transformer blocks and layers not found in Flax/JAX.
  • Includes implementations of models such as Gemma, LLaMA 3, Mistral, GPT-3/4, T5, Whisper, ViT, and CLIP.
  • Offers data-parallel trainers for multi-GPU/TPU training and simplified data handling via custom dataloaders.
  • Accelerates classical ML models (PCA, K-means, etc.) on GPUs/TPUs, as sketched after this list.
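As an illustration of the classical-ML point above, a routine like PCA can be written in a few lines of plain JAX and jit-compiled so it runs on CPU, GPU, or TPU unchanged. This is an illustrative sketch, not nanodl's PCA API; pca_project is a hypothetical name.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n_components")
def pca_project(x, n_components=2):
    """Project rows of x onto their top principal components via SVD.
    jit-compiled, so it runs on whatever accelerator JAX sees."""
    x_centered = x - x.mean(axis=0)
    _, _, vt = jnp.linalg.svd(x_centered, full_matrices=False)
    return x_centered @ vt[:n_components].T

x = jax.random.normal(jax.random.PRNGKey(0), (1000, 64))
print(pca_project(x).shape)  # (1000, 2)
```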

Maintenance & Community

  • The project encourages contributions and feedback via Discord, issues, and pull requests.
  • Experimental features like MAMBA, KAN, BitNet, GAT, and RLHF are available for direct copying from the repository.
  • The long-term goal is to build "nano" versions of popular models (<1B parameters) with competitive performance.

Licensing & Compatibility

  • The library is available under an unspecified license, but the README implies open contribution and use. Further clarification on licensing is recommended for commercial applications.

Limitations & Caveats

  • The project is explicitly stated to be in development ("still in dev, works great but roughness is expected"), with experimental features not yet packaged. Contributions are highly encouraged to address this.
Health Check

  • Last commit: 11 months ago
  • Responsiveness: 1 week
  • Pull Requests (30d): 0
  • Issues (30d): 0

Star History

  • 5 stars in the last 90 days

Explore Similar Projects

Starred by Jeremy Howard (Cofounder of fast.ai) and Stas Bekman (Author of Machine Learning Engineering Open Book; Research Engineer at Snowflake).

SwissArmyTransformer by THUDM

0.3%
1k
Transformer library for flexible model development
created 3 years ago
updated 7 months ago
Starred by Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), Thomas Wolf (Cofounder of Hugging Face), and 3 more.

levanter by stanford-crfm

0.5%
628
Framework for training foundation models with JAX
created 3 years ago
updated 1 day ago
Starred by Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), Chip Huyen (Author of AI Engineering, Designing Machine Learning Systems), and 6 more.

EasyLM by young-geng

0.2%
2k
LLM training/finetuning framework in JAX/Flax
created 2 years ago
updated 11 months ago