INT-FlashAttention: Enabling Flash Attention for INT8 Quantization

📄 arXiv: 2409.16997v2 📥 PDF

作者: Shimao Chen, Zirui Liu, Zhiying Wu, Ce Zheng, Peizhuang Cong, Zihan Jiang, Yuhan Wu, Lei Su, Tong Yang

分类: cs.LG, cs.AI

发布日期: 2024-09-25 (更新: 2024-09-26)


💡 一句话要点

提出INT-FlashAttention,实现INT8量化加速FlashAttention推理。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: FlashAttention INT8量化 后训练量化 大语言模型 推理加速

📋 核心要点

  1. 自注意力机制作为大型语言模型的基础,面临着序列长度带来的二次时间和内存复杂度挑战。
  2. INT-FlashAttention的核心在于设计了一种与FlashAttention兼容的INT8量化方案,从而在保证精度的前提下加速推理。
  3. 实验结果表明,INT-FlashAttention相比于FP16/FP8版本的FlashAttention,推理速度提升72%,量化误差降低82%。

📝 摘要(中文)

本文提出了INT-FlashAttention,这是首个与FlashAttention前向流程兼容的INT8量化架构,显著提升了FlashAttention在Ampere GPU上的推理速度。该方法使用全INT8激活和通用矩阵乘法(GEMM)内核实现了INT-FlashAttention原型,使其成为首个具有全INT8输入的Attention算子。作为一个通用的token级别后训练量化框架,INT-FlashAttention也兼容其他数据格式,如INT4等。实验结果表明,与使用FP16和FP8数据格式的标准FlashAttention相比,INT-FlashAttention实现了72%的推理速度提升和82%的量化误差降低。

🔬 方法详解

问题定义:现有FlashAttention虽然通过优化内存访问加速了自注意力计算,但仍然面临计算量大的问题。直接应用量化方法到FlashAttention存在兼容性问题,无法充分利用硬件加速特性,导致推理速度受限。

核心思路:INT-FlashAttention的核心思路是将FlashAttention与INT8量化相结合,设计一种与FlashAttention前向流程兼容的量化方案。通过将激活值和权重都量化到INT8,并使用优化的INT8 GEMM内核,从而在保证精度的前提下显著加速推理。

技术框架:INT-FlashAttention是一个token级别的后训练量化框架。整体流程与FlashAttention类似,但在计算过程中,激活值和权重都被量化到INT8。然后,使用优化的INT8 GEMM内核进行矩阵乘法计算。最后,将结果反量化回FP16或FP8。

关键创新:INT-FlashAttention的关键创新在于其与FlashAttention前向流程的兼容性。通过精心设计量化和反量化过程,INT-FlashAttention能够充分利用FlashAttention的内存优化特性,同时利用INT8量化带来的计算加速。这是首个全INT8输入的Attention算子。

关键设计:INT-FlashAttention使用token级别的后训练量化。量化参数(例如缩放因子和零点)是针对每个token动态计算的。损失函数采用量化误差最小化策略,以保证量化后的模型精度。具体网络结构与FlashAttention保持一致,无需修改。

🖼️ 关键图片

fig_0
fig_1

📊 实验亮点

实验结果表明,INT-FlashAttention在Ampere GPU上实现了显著的性能提升。与使用FP16和FP8数据格式的标准FlashAttention相比,INT-FlashAttention实现了72%的推理速度提升和82%的量化误差降低。这表明INT-FlashAttention在加速推理的同时,能够有效保持模型精度。

🎯 应用场景

INT-FlashAttention可广泛应用于各种需要加速大型语言模型推理的场景,例如移动设备、边缘计算和云计算。通过降低计算复杂度和内存占用,INT-FlashAttention使得在资源受限的设备上部署大型语言模型成为可能,并能显著提升云端推理服务的效率和吞吐量。

📄 摘要(原文)

As the foundation of large language models (LLMs), self-attention module faces the challenge of quadratic time and memory complexity with respect to sequence length. FlashAttention accelerates attention computation and reduces its memory usage by leveraging the GPU memory hierarchy. A promising research direction is to integrate FlashAttention with quantization methods. This paper introduces INT-FlashAttention, the first INT8 quantization architecture compatible with the forward workflow of FlashAttention, which significantly improves the inference speed of FlashAttention on Ampere GPUs. We implement our INT-FlashAttention prototype with fully INT8 activations and general matrix-multiplication (GEMM) kernels, making it the first attention operator with fully INT8 input. As a general token-level post-training quantization framework, INT-FlashAttention is also compatible with other data formats like INT4, etc. Experimental results show INT-FlashAttention achieves 72% faster inference speed and 82% smaller quantization error compared to standard FlashAttention with FP16 and FP8 data format.