FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
作者: Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao
分类: cs.LG, cs.AI
发布日期: 2024-07-11 (更新: 2024-07-12)
💡 一句话要点
FlashAttention-3:通过异步和低精度加速Transformer Attention计算。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: FlashAttention Transformer Attention机制 GPU加速 低精度计算 异步计算 Hopper GPU 大型语言模型
📋 核心要点
- 现有FlashAttention系列方法在最新GPU硬件上的利用率不足,未能充分发挥硬件性能。
- FlashAttention-3利用异步计算、块状交错和低精度量化等技术,优化数据移动和计算流程。
- 实验表明,FlashAttention-3在H100 GPU上实现了显著加速,并降低了低精度计算的数值误差。
📝 摘要(中文)
Attention机制是Transformer架构的核心,也是大型语言模型和长上下文应用中的瓶颈。FlashAttention通过最小化内存读写来加速GPU上的Attention计算。然而,它尚未充分利用最新硬件中的新功能,FlashAttention-2在H100 GPU上的利用率仅为35%。本文提出了三种主要技术来加速Hopper GPU上的Attention计算:利用Tensor Cores和TMA的异步性,(1)通过warp-specialization重叠整体计算和数据移动,(2)交错块状matmul和softmax操作,以及(3)利用硬件对FP8低精度支持的块量化和非相干处理。实验表明,FlashAttention-3在H100 GPU上实现了1.5-2.0倍的加速,FP16精度下达到740 TFLOPs/s(75%利用率),FP8精度下接近1.2 PFLOPs/s。验证结果表明,FP8 FlashAttention-3的数值误差比基线FP8 Attention低2.6倍。
🔬 方法详解
问题定义:论文旨在解决Transformer模型中Attention机制计算效率低下的问题,尤其是在长序列和大型模型中。现有FlashAttention系列方法虽然在减少内存读写方面有所改进,但未能充分利用最新GPU硬件(如H100)的特性,导致硬件利用率不高,计算速度受限。
核心思路:论文的核心思路是利用Hopper GPU架构中的异步计算能力(Tensor Cores和TMA)以及对低精度数据类型的硬件加速支持,通过优化数据移动和计算流程,实现更高的硬件利用率和更快的Attention计算速度。
技术框架:FlashAttention-3的技术框架主要包括以下几个部分:1) Warp-specialization:利用warp内的线程协同完成数据加载和计算,实现计算和数据移动的重叠。2) 块状交错:将块状矩阵乘法和softmax操作交错执行,减少中间结果的内存读写。3) 块量化和非相干处理:利用硬件对FP8低精度数据类型的支持,对数据进行量化,并在块级别进行非相干处理,进一步提高计算效率。
关键创新:FlashAttention-3的关键创新在于充分利用了Hopper GPU架构的异步计算能力和低精度加速特性,通过warp-specialization、块状交错和块量化等技术,实现了计算和数据移动的重叠,减少了内存读写,提高了硬件利用率。与现有FlashAttention方法相比,FlashAttention-3更加充分地利用了硬件资源,实现了更高的计算速度。
关键设计:在warp-specialization中,需要合理分配warp内的线程,以实现最佳的计算和数据移动重叠效果。在块状交错中,需要仔细调整块的大小和交错方式,以最大程度地减少内存读写。在块量化中,需要选择合适的量化策略,以保证计算精度。此外,还需要针对不同的硬件平台进行优化,以充分发挥硬件性能。
🖼️ 关键图片
📊 实验亮点
FlashAttention-3在H100 GPU上实现了显著的性能提升。在FP16精度下,达到了740 TFLOPs/s的计算速度,硬件利用率达到75%。在FP8精度下,计算速度接近1.2 PFLOPs/s。此外,FP8 FlashAttention-3的数值误差比基线FP8 Attention低2.6倍,表明其在低精度计算下具有更高的精度。
🎯 应用场景
FlashAttention-3的潜在应用领域包括大型语言模型、长文本处理、图像识别、语音识别等。通过提高Attention计算的效率,FlashAttention-3可以加速这些应用的训练和推理过程,降低计算成本,并支持更大规模的模型和更长的上下文长度。这对于提高模型的性能和扩展应用范围具有重要意义。
📄 摘要(原文)
Attention, as a core layer of the ubiquitous Transformer architecture, is the bottleneck for large language models and long-context applications. FlashAttention elaborated an approach to speed up attention on GPUs through minimizing memory reads/writes. However, it has yet to take advantage of new capabilities present in recent hardware, with FlashAttention-2 achieving only 35% utilization on the H100 GPU. We develop three main techniques to speed up attention on Hopper GPUs: exploiting asynchrony of the Tensor Cores and TMA to (1) overlap overall computation and data movement via warp-specialization and (2) interleave block-wise matmul and softmax operations, and (3) block quantization and incoherent processing that leverages hardware support for FP8 low-precision. We demonstrate that our method, FlashAttention-3, achieves speedup on H100 GPUs by 1.5-2.0$\times$ with FP16 reaching up to 740 TFLOPs/s (75% utilization), and with FP8 reaching close to 1.2 PFLOPs/s. We validate that FP8 FlashAttention-3 achieves 2.6$\times$ lower numerical error than a baseline FP8 attention.