Efficient CUDA kernels for MLA decoding
FlashMLA provides highly optimized CUDA kernels for Multi-head Latent Attention (MLA) decoding on NVIDIA Hopper GPUs, targeting large language model inference. It offers significant speedups for compute-bound decoding workloads by leveraging a paged KV cache and optimized tiling, benefiting researchers and engineers working on high-throughput LLM serving.
How It Works
FlashMLA implements MLA decoding kernels optimized for the Hopper architecture, supporting BF16 and FP16 precision. It uses a paged KV cache with a block size of 64 and tiling strategies inspired by FlashAttention and CUTLASS to maximize throughput. The kernels are most advantageous in compute-intensive scenarios where the number of query heads multiplied by the number of query tokens per request exceeds 64, where they sustain high TFLOPS.
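The decode path can be exercised roughly as follows. This is a minimal sketch adapted from the upstream usage example: the module name flash_mla, the functions get_mla_metadata and flash_mla_with_kvcache, their argument order, and the tensor shapes are taken from the upstream README at one point in time and may differ across versions; all sizes are illustrative placeholders.

import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

# Illustrative sizes only (DeepSeek-style MLA: 512 latent dims + 64 RoPE dims).
b, s_q, h_q, h_kv = 4, 1, 128, 1          # batch, query tokens, query heads, KV heads
d, dv, block_size = 576, 512, 64          # QK head dim, V head dim, paged-cache block size
max_seqlen = 1024
blocks_per_seq = max_seqlen // block_size

device, dtype = "cuda", torch.bfloat16
cache_seqlens = torch.full((b,), max_seqlen, dtype=torch.int32, device=device)
block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32, device=device).view(b, blocks_per_seq)

q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
kv_cache = torch.randn(b * blocks_per_seq, block_size, h_kv, d, dtype=dtype, device=device)

# Scheduler metadata depends only on cache lengths and head configuration.
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

o, lse = flash_mla_with_kvcache(
    q, kv_cache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True,
)

Because the tile-scheduler metadata depends only on the batch's cache lengths and head configuration, it can be computed once per decoding step and reused across all layers.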
Quick Start & Requirements
# Build and install the CUDA extension from the repository root
python setup.py install
# Run the bundled test/benchmark script
python tests/test_flash_mla.py
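If the build succeeds, a quick sanity check is to import the extension; this assumes the package installs under the module name flash_mla, as used in the sketch above.

python -c "import flash_mla"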
Maintenance & Community
The project is actively maintained by DeepSeek AI and acknowledges inspiration from FlashAttention 2 and 3 and from CUTLASS. Community versions are available for MetaX, Moore Threads, Hygon DCU, Intellifusion, Iluvatar Corex, and AMD Instinct GPUs.
Licensing & Compatibility
The repository does not explicitly state a license in the provided README. Users should verify licensing for commercial use or integration into closed-source projects.
Limitations & Caveats
The new kernel primarily targets compute-intensive settings; for memory-bound cases, the older kernel at commit b31bfe7 is recommended. Compatibility with older CUDA versions or non-Hopper NVIDIA architectures is not specified in the provided README.
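As a rough illustration of that guidance, a hypothetical helper could choose between the two kernels; the 64 threshold comes from the description above, while the function name and structure are not part of FlashMLA.

def use_new_compute_bound_kernel(num_q_heads: int, q_tokens_per_request: int) -> bool:
    # Heuristic from the description above: the new kernel targets workloads where
    # query heads * query tokens per request exceeds 64; otherwise the memory-bound
    # kernel from commit b31bfe7 may be the better fit.
    return num_q_heads * q_tokens_per_request > 64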