ELF by lillian039

JAX implementation of continuous diffusion language models

Created 2 weeks ago

750 stars

Top 45.8% on SourcePulse

Project Summary

ELF (Embedded Language Flows) is a JAX implementation of continuous diffusion language models based on Flow Matching. It tackles the difficulty of generating discrete text by operating almost entirely in a continuous embedding space, which makes it straightforward to reuse techniques from image diffusion models such as classifier-free guidance. The authors suggest this can lead to more fluent and efficient text generation. The project targets researchers and practitioners working on large language models.

How It Works

ELF uses continuous-time Flow Matching, a class of continuous diffusion models. Its core idea is to keep the data in a continuous embedding space throughout the diffusion process and to map it to discrete tokens only at the final time step, via a shared-weight network. This design lets established image diffusion techniques such as classifier-free guidance (CFG) be applied directly, and lets the model progressively refine ungrammatical sequences into fluent text as it denoises trajectories in the continuous space.
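The training objective and sampling loop described above can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not ELF's code: it assumes a linear (rectified-flow) interpolation path with t = 0 as Gaussian noise and t = 1 as clean embeddings, a fixed-step Euler sampler, and the standard CFG combination of conditional and unconditional velocities. `velocity_fn` is a hypothetical stand-in for the denoising network, and `decode_tokens` assumes the shared weights are the input embedding matrix reused as the output projection (weight tying); the README's "shared-weight network" may be implemented differently.

```python
import jax
import jax.numpy as jnp

# Assumed convention: t = 0 is pure Gaussian noise, t = 1 is clean token embeddings.
# velocity_fn(x_t, t, cond) is a hypothetical stand-in for the denoising network.


def flow_matching_loss(velocity_fn, x1, cond, key):
    """Flow-matching loss on the linear path x_t = (1 - t) * x0 + t * x1."""
    key_noise, key_t = jax.random.split(key)
    x0 = jax.random.normal(key_noise, x1.shape)           # noise sample
    t = jax.random.uniform(key_t, (x1.shape[0], 1, 1))    # one time per sequence
    x_t = (1.0 - t) * x0 + t * x1                         # interpolated embeddings
    target = x1 - x0                                      # velocity of the linear path
    pred = velocity_fn(x_t, t, cond)
    return jnp.mean((pred - target) ** 2)


def cfg_velocity(velocity_fn, x_t, t, cond, null_cond, scale):
    """Classifier-free guidance applied to the predicted velocity."""
    v_cond = velocity_fn(x_t, t, cond)
    v_uncond = velocity_fn(x_t, t, null_cond)
    return v_uncond + scale * (v_cond - v_uncond)


def euler_sample(velocity_fn, x0, cond, null_cond, scale, num_steps):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 with fixed Euler steps."""
    dt = 1.0 / num_steps

    def step(x, i):
        t = i * dt * jnp.ones((x.shape[0], 1, 1))         # current time, per sequence
        v = cfg_velocity(velocity_fn, x, t, cond, null_cond, scale)
        return x + dt * v, None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(num_steps))
    return x1


def decode_tokens(x1, embedding_table):
    """Map final embeddings (batch, len, dim) to token ids via a tied (vocab, dim) matrix."""
    logits = jnp.einsum("bld,vd->blv", x1, embedding_table)  # (batch, len, vocab)
    return jnp.argmax(logits, axis=-1)
```

With a toy `velocity_fn` that simply returns its conditioning tensor, `euler_sample` with guidance scale 2.0 and an all-zero null condition moves every embedding by exactly 2.0 over the unit time interval, which is a quick sanity check of the integrator and the CFG arithmetic.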

Quick Start & Requirements

  • Installation: pip install -r requirements.txt
  • Prerequisites: JAX, TPUs, and optionally Weights & Biases (WandB) for experiment tracking.
  • Checkpoints: Pre-trained checkpoints are downloaded automatically from Hugging Face (embedded-language-flows/) through the --checkpoint_path argument.
  • Documentation: Configuration files (configs/) and Python scripts (src/eval.py, src/train.py) provide detailed usage examples.
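Because the code is developed and tested mainly on TPUs, it can be worth confirming which backend JAX actually picked up before launching src/train.py or src/eval.py. The helper below is a hypothetical convenience, not part of the repository; it only uses the public `jax.default_backend()` and `jax.devices()` calls.

```python
import jax


def check_accelerators(expected_backend="tpu"):
    """Return (backend, device_count, device_kinds); warn if not on the expected backend.

    Hypothetical helper for illustration; not part of the ELF repository.
    """
    backend = jax.default_backend()  # platform name such as "cpu", "gpu" or "tpu"
    devices = jax.devices()          # devices of the default backend
    kinds = sorted({d.device_kind for d in devices})
    if backend != expected_backend:
        print(f"Warning: JAX is using '{backend}', not '{expected_backend}'; "
              "speed and metrics may differ from the reported TPU runs.")
    return backend, len(devices), kinds


if __name__ == "__main__":
    print(check_accelerators())
```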

Highlighted Details

  • Models: Offers three sizes: ELF-B (105M), ELF-M (342M), and ELF-L (652M).
  • Performance: On OpenWebText (unconditional generation), ELF-B with 32 sampling steps achieves a generative perplexity (Gen. PPL) of 24.1 and an entropy of 5.15.
  • Conditional Tasks: For WMT14 De-En translation, ELF-B yields a BLEU score of 26.4. For XSum summarization, it achieves ROUGE-1/2/L scores of 36.0/12.2/27.8.
  • Efficiency: Reported to outperform other discrete and continuous diffusion language models (DLMs) while using substantially fewer training tokens.
  • Hardware Focus: Developed and tested primarily on TPUs, with specific results reported on TPU v5p-64.
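The two unconditional-generation metrics can be computed roughly as follows. This sketch assumes conventions common in the diffusion-LM literature, which the README does not spell out: Gen. PPL is the exponentiated mean negative log-likelihood that a separate, pretrained evaluator language model assigns to the generated tokens, and entropy is the entropy of the empirical token distribution within a single sample, averaged over samples. The evaluator model and the logarithm base behind the reported 5.15 are not stated, so natural logs are used here and all function names are illustrative.

```python
import jax
import jax.numpy as jnp


def generative_perplexity(token_logprobs, mask):
    """exp(mean NLL) of generated tokens under an evaluator LM (assumed convention).

    token_logprobs: (batch, len) log-probabilities the evaluator assigns to each token.
    mask: (batch, len) 1.0 for scored positions, 0.0 for padding.
    """
    nll = -jnp.sum(token_logprobs * mask) / jnp.sum(mask)
    return jnp.exp(nll)


def sample_entropy(tokens, vocab_size):
    """Entropy in nats of the token histogram of one generated sequence."""
    counts = jnp.bincount(tokens, length=vocab_size)
    p = counts / jnp.sum(counts)
    # Treat 0 * log(0) as 0 so unused vocabulary entries contribute nothing.
    return -jnp.sum(jnp.where(p > 0, p * jnp.log(p), 0.0))


def mean_entropy(batch_tokens, vocab_size):
    """Average per-sample entropy over a (batch, len) array of token ids."""
    per_sample = jax.vmap(lambda t: sample_entropy(t, vocab_size))(batch_tokens)
    return jnp.mean(per_sample)
```

Low Gen. PPL on its own can be reached by repetitive, low-diversity samples, which is why it is usually reported alongside entropy.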

Maintenance & Community

No specific details regarding maintainers, community channels (e.g., Discord, Slack), or active development signals are present in the provided README.

Licensing & Compatibility

The project is released under the MIT License, which permits commercial use and integration into closed-source projects, provided the copyright and license notice is retained.

Limitations & Caveats

The implementation is primarily optimized and tested for TPUs; performance on other hardware accelerators may vary. A PyTorch version is planned but not yet available. Reported metrics may vary slightly depending on the compute setup used for evaluation.

Health Check

  • Last Commit: 1 week ago
  • Responsiveness: Inactive
  • Pull Requests (30d): 1
  • Issues (30d): 5
  • Star History: 755 stars in the last 16 days

Explore Similar Projects

Starred by Chip Huyen (Author of "AI Engineering", "Designing Machine Learning Systems"), Wing Lian (Founder of Axolotl AI), and 10 more.

open_flamingo by mlfoundations

0.0%
4k
Open-source framework for training large multimodal models
Created 3 years ago
Updated 1 year ago