flax by google

Neural network library for JAX, designed for flexibility in research

created 5 years ago
6,714 stars

Top 7.7% on sourcepulse

Project Summary

Flax is a neural network library and ecosystem for JAX, designed for flexibility and ease of use in research. It targets researchers and developers working with JAX, offering a Pythonic API for building, inspecting, and debugging neural networks, with benefits such as simplified state management and first-class mutability.

How It Works

Flax NNX, the latest iteration, introduces first-class support for Python reference semantics, allowing models to be expressed as regular Python objects. This enables natural reference sharing and mutability, simplifying complex model architectures and state management compared to more rigid functional approaches. NNX evolved from the earlier Flax Linen API and retains its core neural network components along with utilities for distributed training and serialization.
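
As a concrete illustration, here is a minimal sketch in the NNX style (the MLP class and sizes are our own; nnx.Module, nnx.Linear, and nnx.Rngs are from the public API):

    # The model is a plain Python object; parameters can be read and
    # mutated through ordinary attribute access.
    from flax import nnx
    import jax.numpy as jnp

    class MLP(nnx.Module):
        def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
            self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
            self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

        def __call__(self, x):
            return self.linear2(nnx.relu(self.linear1(x)))

    model = MLP(din=4, dmid=8, dout=2, rngs=nnx.Rngs(0))
    y = model(jnp.ones((1, 4)))        # eager call; no separate init/apply step
    model.linear1.kernel.value *= 0.5  # state is ordinary, mutable data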

Quick Start & Requirements
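
Flax is distributed on PyPI. The card lists no install steps, so the following is a minimal sketch, assuming pip and a working JAX backend (JAX is pulled in as a dependency; GPU/TPU users should install the matching JAX build first):

    # Install: pip install flax
    import flax
    from flax import nnx

    # Smoke test: build one seeded layer and check the installed version.
    layer = nnx.Linear(in_features=3, out_features=4, rngs=nnx.Rngs(0))
    print(flax.__version__)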

Highlighted Details

  • Offers a comprehensive neural network API including Linear, Conv, BatchNorm, LayerNorm, Attention, LSTMCell, and GRUCell (see the sketch after this list).
  • Includes utilities for replicated training, serialization, checkpointing, metrics, and on-device prefetching.
  • Provides educational examples such as MNIST, Gemma language model inference, and Transformer LM1B.
  • Developed in close collaboration with the JAX team.
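
To make the layer list concrete, here is a short sketch wiring two of those layers together (NNX API assumed; the shapes are illustrative):

    # Conv + BatchNorm over an NHWC image batch, in training mode.
    from flax import nnx
    import jax.numpy as jnp

    rngs = nnx.Rngs(0)
    conv = nnx.Conv(in_features=3, out_features=16, kernel_size=(3, 3), rngs=rngs)
    norm = nnx.BatchNorm(num_features=16, rngs=rngs)

    x = jnp.ones((1, 28, 28, 3))
    y = norm(conv(x), use_running_average=False)  # batch statistics updated eagerly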

Maintenance & Community

Licensing & Compatibility

  • Apache 2.0 License.
  • Permissive license suitable for commercial use and integration into closed-source projects.

Limitations & Caveats

The project is actively developed, and NNX is the newer of its APIs. While the core API is expected to remain stable, users should follow the changelog for deprecations and minor breaking changes.

Health Check

  • Last commit: 17 hours ago
  • Responsiveness: Inactive
  • Pull requests (30d): 51
  • Issues (30d): 11
  • Star history: 210 stars in the last 90 days

Starred by Chip Huyen (Author of AI Engineering, Designing Machine Learning Systems), Alex Cheema (Cofounder of EXO Labs), and 1 more.

Explore Similar Projects

recurrent-pretraining by seal-rg

Top 0.1% · 806 stars
Pretraining code for depth-recurrent language model research
created 5 months ago · updated 2 weeks ago
Starred by Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), Thomas Wolf (Cofounder of Hugging Face), and 3 more.

levanter by stanford-crfm

Top 0.5% · 628 stars
Framework for training foundation models with JAX
created 3 years ago · updated 17 hours ago
Starred by Jiayi Pan (Author of SWE-Gym; AI Researcher at UC Berkeley), Chip Huyen (Author of AI Engineering, Designing Machine Learning Systems), and 6 more.

EasyLM by young-geng

Top 0.2% · 2k stars
LLM training/finetuning framework in JAX/Flax
created 2 years ago · updated 11 months ago