TransformerEngine by NVIDIA

Library for Transformer model acceleration on NVIDIA GPUs

Created 3 years ago
2,729 stars

Top 17.4% on SourcePulse

Project Summary

NVIDIA Transformer Engine is a library designed to accelerate Transformer model training and inference on NVIDIA GPUs by leveraging FP8 precision. It targets researchers and engineers working with large language models (LLMs) and other Transformer architectures, offering significant performance gains and reduced memory usage.

How It Works

Transformer Engine provides optimized building blocks and an automatic mixed-precision API that integrate into existing deep learning workflows. The library manages FP8 scaling factors and related state internally, so adopting FP8 requires minimal code changes; on compatible NVIDIA hardware, FP8 improves throughput and memory efficiency over FP16 without degrading accuracy.
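In PyTorch, the mixed-precision API follows the pattern of wrapping forward passes of Transformer Engine modules in an fp8_autocast context and passing it a scaling recipe. Below is a minimal sketch based on the published quickstart pattern; module and recipe arguments may vary between versions:

    # Minimal FP8 mixed-precision sketch (PyTorch API).
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Drop-in replacement for torch.nn.Linear with FP8 support.
    model = te.Linear(768, 768, bias=True).cuda()
    inp = torch.randn(2048, 768, device="cuda")

    # Delayed-scaling recipe: TE tracks amax history and scaling factors internally.
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)

    out.sum().backward()  # gradients flow as with standard PyTorch modules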

Quick Start & Requirements

  • Installation: Recommended via Docker images from NGC Catalog. Pip installation is also supported: pip install --no-build-isolation transformer_engine[pytorch,jax]. A post-install sanity check is sketched after this list.
  • Prerequisites:
    • Hardware: NVIDIA GPUs with Compute Capability 8.9+ (Hopper, Ada, Blackwell) for FP8 features. Ampere and later for FP16/BF16 optimizations.
    • OS: Linux (limited support on WSL2).
    • Software: CUDA 12.1+ (12.8+ for Blackwell), cuDNN 9.3+, GCC 9+/Clang 10+, Python 3.12.
    • Source Build: CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+.
  • Resources: Docker images provide pre-configured environments. Pip installation requires a C++ compiler and CUDA toolkit. Compilation with FlashAttention can be resource-intensive.
  • Docs: User Guide, Examples, Quickstart Notebook.
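
After installing, a quick way to confirm that the package imports and the GPU meets the FP8 requirement is a short script such as the following (an illustrative sketch, not an official verification step):

    # Post-install sanity check (illustrative only).
    import torch
    import transformer_engine.pytorch as te

    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: {major}.{minor}")  # 8.9+ is needed for FP8 kernels

    layer = te.Linear(64, 64).cuda()  # confirms the extension modules load on the GPU
    print(type(layer))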

Highlighted Details

  • Supports FP8 precision on NVIDIA Hopper, Ada, and Blackwell GPUs for enhanced performance and memory efficiency.
  • Provides optimized kernels and an automatic mixed-precision API for PyTorch and JAX (see the building-block sketch after this list).
  • Demonstrates FP8 convergence with no significant difference in training loss curves compared to BF16 across various LLM architectures.
  • Integrated with major frameworks like DeepSpeed, Hugging Face Accelerate, and NVIDIA NeMo.
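
To illustrate the building-block API, here is a sketch of the fused PyTorch TransformerLayer module; the constructor arguments mirror the project's quickstart example and should be treated as indicative:

    # Sketch: fused Transformer block as a drop-in PyTorch module.
    import torch
    import transformer_engine.pytorch as te

    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
    ).cuda()

    x = torch.randn(128, 4, 1024, device="cuda")  # [seq, batch, hidden]
    y = layer(x)  # can also run inside te.fp8_autocast() on FP8-capable GPUs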

Maintenance & Community

  • Actively maintained by NVIDIA.
  • Integrations with numerous popular LLM frameworks.
  • Contributing Guide.

Licensing & Compatibility

  • License: Apache 2.0.
  • Compatible with commercial use and closed-source linking.

Limitations & Caveats

  • FP8 features require specific NVIDIA GPU architectures (Compute Capability 8.9+).
  • WSL2 support is limited.
  • FlashAttention compilation can be memory-intensive.
  • A breaking change in v1.7 altered padding mask definitions in the PyTorch implementation.

Health Check

  • Last Commit: 18 hours ago
  • Responsiveness: 1 week
  • Pull Requests (30d): 97
  • Issues (30d): 18
  • Star History: 80 stars in the last 30 days

Explore Similar Projects

Starred by Nat Friedman (Former CEO of GitHub), Chip Huyen (Author of "AI Engineering" and "Designing Machine Learning Systems"), and 15 more.

FasterTransformer by NVIDIA

  6k stars · Top 0.1%
  Optimized transformer library for inference
  Created 4 years ago · Updated 1 year ago
  Starred by Andrej Karpathy (Founder of Eureka Labs; formerly at Tesla and OpenAI; author of CS 231n), Jiayi Pan (Author of SWE-Gym; MTS at xAI), and 34 more.

flash-attention by Dao-AILab

  20k stars · Top 0.6%
  Fast, memory-efficient attention implementation
  Created 3 years ago · Updated 1 day ago