MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention

📄 arXiv: 2407.02490v2 📥 PDF

作者: Huiqiang Jiang, Yucheng Li, Chengruidong Zhang, Qianhui Wu, Xufang Luo, Surin Ahn, Zhenhua Han, Amir H. Abdi, Dongsheng Li, Chin-Yew Lin, Yuqing Yang, Lili Qiu

分类: cs.CL, cs.LG

发布日期: 2024-07-02 (更新: 2024-10-30)

备注: Accepted at NeurIPS 2024 (Spotlight)


💡 一句话要点

MInference:通过动态稀疏注意力加速长文本LLM的预填充阶段

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 长文本LLM 稀疏注意力 预填充加速 动态索引 GPU优化

📋 核心要点

  1. 长文本LLM推理的计算挑战,特别是预填充阶段的二次复杂度,限制了其广泛应用。
  2. MInference通过识别并利用长文本注意力矩阵中的稀疏模式,动态构建稀疏索引,加速计算。
  3. 实验表明,MInference在保持准确性的前提下,显著降低了长文本LLM预填充阶段的推理延迟,最高可达10倍。

📝 摘要(中文)

本文提出了一种名为MInference(Milliontokens Inference)的稀疏计算方法,旨在加速长序列处理中大型语言模型(LLM)的预填充阶段。针对长文本注意力矩阵中存在的A形、垂直斜线和块稀疏三种独特模式,MInference离线确定每个注意力头的最佳模式,并在推理期间基于分配的模式动态构建稀疏索引。通过优化的GPU内核执行高效的稀疏注意力计算,显著降低长文本LLM预填充阶段的延迟。该技术可直接应用于现有LLM,无需修改预训练设置或进行额外的微调。在InfiniteBench、RULER、PG-19和Needle In A Haystack等下游任务以及LLaMA-3-1M、GLM4-1M、Yi-200K、Phi-3-128K和Qwen2-128K等模型上的评估表明,MInference在A100上有效降低了预填充的推理延迟,最高可达10倍,同时保持了准确性。

🔬 方法详解

问题定义:论文旨在解决长文本LLM推理中预填充阶段计算量过大的问题。现有方法在应用于长文本LLM时,往往难以在保持准确性的同时保证效率,导致推理速度慢,成本高昂。

核心思路:论文的核心思路是利用长文本注意力矩阵的稀疏性。通过观察发现,长文本的注意力矩阵存在特定的稀疏模式(A形、垂直斜线、块稀疏)。针对不同的注意力头,选择最合适的稀疏模式,并只计算非零元素,从而减少计算量。

技术框架:MInference的整体框架包含以下几个主要步骤:1) 离线分析:对模型的注意力矩阵进行离线分析,确定每个注意力头最适合的稀疏模式。2) 动态索引构建:在推理过程中,根据离线分析的结果,为每个注意力头动态构建稀疏索引。3) 稀疏注意力计算:利用优化的GPU内核,根据稀疏索引进行高效的稀疏注意力计算。

关键创新:MInference的关键创新在于动态稀疏注意力计算。与静态稀疏方法不同,MInference能够根据不同的注意力头自适应地选择最佳稀疏模式,从而更好地利用注意力矩阵的稀疏性。此外,MInference无需修改预训练模型或进行额外的微调,可以直接应用于现有的LLM。

关键设计:MInference的关键设计包括:1) 稀疏模式的选择:论文定义了三种稀疏模式(A形、垂直斜线、块稀疏),并提出了一种离线分析方法来确定每个注意力头最适合的模式。2) 稀疏索引的构建:论文设计了一种高效的稀疏索引结构,用于存储非零元素的位置信息。3) GPU内核优化:论文针对稀疏注意力计算,优化了GPU内核,提高了计算效率。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,MInference在多种长文本任务和模型上均取得了显著的性能提升。例如,在A100 GPU上,MInference可以将LLaMA-3-1M、GLM4-1M、Yi-200K、Phi-3-128K和Qwen2-128K等模型的预填充推理延迟降低高达10倍,同时保持了与原始模型相当的准确性。这些结果验证了MInference在加速长文本LLM推理方面的有效性。

🎯 应用场景

MInference可广泛应用于需要处理长文本的LLM推理场景,例如长文档摘要、代码生成、对话系统等。通过降低推理延迟和计算成本,MInference有助于推动长文本LLM的实际应用,并为用户提供更流畅、高效的交互体验。该技术还有潜力应用于其他类型的深度学习模型,以加速其推理过程。

📄 摘要(原文)

The computational challenges of Large Language Model (LLM) inference remain a significant barrier to their widespread deployment, especially as prompt lengths continue to increase. Due to the quadratic complexity of the attention computation, it takes 30 minutes for an 8B LLM to process a prompt of 1M tokens (i.e., the pre-filling stage) on a single A100 GPU. Existing methods for speeding up prefilling often fail to maintain acceptable accuracy or efficiency when applied to long-context LLMs. To address this gap, we introduce MInference (Milliontokens Inference), a sparse calculation method designed to accelerate pre-filling of long-sequence processing. Specifically, we identify three unique patterns in long-context attention matrices-the A-shape, Vertical-Slash, and Block-Sparsethat can be leveraged for efficient sparse computation on GPUs. We determine the optimal pattern for each attention head offline and dynamically build sparse indices based on the assigned pattern during inference. With the pattern and sparse indices, we perform efficient sparse attention calculations via our optimized GPU kernels to significantly reduce the latency in the pre-filling stage of long-context LLMs. Our proposed technique can be directly applied to existing LLMs without any modifications to the pre-training setup or additional fine-tuning. By evaluating on a wide range of downstream tasks, including InfiniteBench, RULER, PG-19, and Needle In A Haystack, and models including LLaMA-3-1M, GLM4-1M, Yi-200K, Phi-3-128K, and Qwen2-128K, we demonstrate that MInference effectively reduces inference latency by up to 10x for pre-filling on an A100, while maintaining accuracy. Our code is available at https://aka.ms/MInference.