DuoAttention: Efficient Long-Context LLM Inference with Retrieval and Streaming Heads

📄 arXiv: 2410.10819v1 📥 PDF

作者: Guangxuan Xiao, Jiaming Tang, Jingwei Zuo, Junxian Guo, Shang Yang, Haotian Tang, Yao Fu, Song Han

分类: cs.CL

发布日期: 2024-10-14

🔗 代码/项目: GITHUB


💡 一句话要点

DuoAttention:利用检索头和流式头实现高效长文本LLM推理

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 长文本LLM 注意力机制 KV缓存 推理优化 检索头 流式头 模型压缩 高效推理

📋 核心要点

  1. 长文本LLM推理面临内存和计算挑战,现有KV缓存剪枝方法难以兼顾效率和长文本能力。
  2. DuoAttention框架的核心思想是区分检索头和流式头,并对它们采用不同的KV缓存策略。
  3. 实验表明,DuoAttention显著降低了长文本推理的内存占用,并加速了解码和预填充过程,同时保持了较高的精度。

📝 摘要(中文)

部署长上下文大型语言模型(LLM)至关重要,但也带来了巨大的计算和内存挑战。跨所有注意力头的键和值(KV)状态缓存消耗大量内存。现有的KV缓存修剪方法要么损害LLM的长上下文能力,要么仅提供有限的效率提升。本文发现,只有一小部分注意力头(即检索头)对于处理长上下文至关重要,并且需要跨所有token的完全注意力。相比之下,所有其他主要关注最近token和注意力汇的头(称为流式头)不需要完全注意力。基于此,我们引入了DuoAttention框架,该框架仅将完整的KV缓存应用于检索头,而对流式头使用轻量级的、恒定长度的KV缓存,从而在不影响其长上下文能力的情况下,减少LLM的解码和预填充内存及延迟。DuoAttention使用轻量级的、基于优化的算法和合成数据来准确识别检索头。我们的方法显著降低了长上下文推理内存,MHA模型最多降低2.55倍,GQA模型最多降低1.67倍,同时将解码速度分别提高2.18倍和1.50倍,并将MHA和GQA模型的预填充速度分别提高1.73倍和1.63倍,与完全注意力相比,精度损失极小。值得注意的是,结合量化,DuoAttention能够在单个A100 GPU上实现具有330万上下文长度的Llama-3-8B解码。代码已在https://github.com/mit-han-lab/duo-attention中提供。

🔬 方法详解

问题定义:长文本LLM推理需要大量的KV缓存,导致内存占用高、推理速度慢。现有的KV缓存剪枝方法要么损害LLM的长文本处理能力,要么效率提升有限,无法满足实际应用需求。

核心思路:论文的核心思路是观察到并非所有注意力头都对长文本处理同等重要。一部分头(检索头)负责检索长距离依赖,需要完整的KV缓存;另一部分头(流式头)主要关注局部信息,可以使用轻量级的KV缓存。通过区分对待这两种头,可以在保证性能的同时降低内存占用。

技术框架:DuoAttention框架主要包含两个阶段:检索头识别和差异化KV缓存。首先,使用基于优化的算法和合成数据来识别检索头。然后,对检索头使用完整的KV缓存,而对流式头使用轻量级的、固定长度的KV缓存。在推理过程中,根据不同的头类型采用不同的注意力计算方式。

关键创新:DuoAttention的关键创新在于发现了注意力头在长文本处理中的差异性,并提出了差异化的KV缓存策略。与传统的KV缓存剪枝方法相比,DuoAttention能够更精确地保留对长文本处理至关重要的信息,从而在保证性能的同时实现更高的效率。

关键设计:检索头的识别算法是基于优化的,目标是找到一组头,使得在移除其他头的情况下,模型在长文本任务上的性能损失最小。轻量级KV缓存可以使用滑动窗口或循环缓存等技术实现。论文还可能涉及一些超参数的设置,例如检索头的数量、轻量级KV缓存的长度等。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,DuoAttention在MHA模型上降低了高达2.55倍的推理内存,加速了解码过程高达2.18倍,加速了预填充过程高达1.73倍。在GQA模型上,推理内存降低了高达1.67倍,解码速度提高了1.50倍,预填充速度提高了1.63倍。同时,精度损失与完全注意力相比极小。结合量化,DuoAttention使得Llama-3-8B能够在单个A100 GPU上处理330万上下文长度。

🎯 应用场景

DuoAttention可应用于各种需要处理长文本的场景,例如长文档摘要、机器翻译、代码生成、对话系统等。该技术可以降低长文本LLM的部署成本,使其能够在资源受限的设备上运行,并提高推理速度,从而改善用户体验。未来,该技术有望推动长文本LLM在更多领域的应用。

📄 摘要(原文)

Deploying long-context large language models (LLMs) is essential but poses significant computational and memory challenges. Caching all Key and Value (KV) states across all attention heads consumes substantial memory. Existing KV cache pruning methods either damage the long-context capabilities of LLMs or offer only limited efficiency improvements. In this paper, we identify that only a fraction of attention heads, a.k.a, Retrieval Heads, are critical for processing long contexts and require full attention across all tokens. In contrast, all other heads, which primarily focus on recent tokens and attention sinks--referred to as Streaming Heads--do not require full attention. Based on this insight, we introduce DuoAttention, a framework that only applies a full KV cache to retrieval heads while using a light-weight, constant-length KV cache for streaming heads, which reduces both LLM's decoding and pre-filling memory and latency without compromising its long-context abilities. DuoAttention uses a lightweight, optimization-based algorithm with synthetic data to identify retrieval heads accurately. Our method significantly reduces long-context inference memory by up to 2.55x for MHA and 1.67x for GQA models while speeding up decoding by up to 2.18x and 1.50x and accelerating pre-filling by up to 1.73x and 1.63x for MHA and GQA models, respectively, with minimal accuracy loss compared to full attention. Notably, combined with quantization, DuoAttention enables Llama-3-8B decoding with 3.3 million context length on a single A100 GPU. Code is provided in https://github.com/mit-han-lab/duo-attention.