SparQ Attention: Bandwidth-Efficient LLM Inference

📄 arXiv: 2312.04985v6 📥 PDF

作者: Luka Ribar, Ivan Chelombiev, Luke Hudlass-Galley, Charlie Blake, Carlo Luschi, Douglas Orr

分类: cs.LG

发布日期: 2023-12-08 (更新: 2024-09-04)


💡 一句话要点

SparQ Attention:通过选择性历史缓存,提升LLM推理带宽效率。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 注意力机制 大型语言模型 推理加速 内存带宽优化 选择性访问

📋 核心要点

  1. LLM推理受限于数据传输,尤其是在处理长序列和大型批次时,内存带宽成为瓶颈。
  2. SparQ Attention通过选择性地从缓存中获取历史信息,更有效地利用内存带宽,从而提高推理吞吐量。
  3. SparQ Attention无需修改预训练或微调,即可直接应用于现有LLM,并在多种模型上实现了高达8倍的数据传输节省。

📝 摘要(中文)

大型语言模型(LLM)推理的计算困难仍然是其广泛部署的重大障碍。许多应用需要支持长输入序列并在大型批次中处理它们,这通常导致token生成受限于数据传输。因此,我们引入了SparQ Attention,这是一种通过在注意力层中更有效地利用内存带宽来提高LLM推理吞吐量的技术,通过选择性地获取缓存的历史记录来实现。我们提出的技术可以直接应用于现成的LLM进行推理,而无需修改预训练设置或进行额外的微调。通过在各种下游任务上评估Llama 2和3、Mistral、Gemma和Pythia模型,我们表明SparQ Attention在不显著降低准确性的情况下,最多可节省8倍的注意力数据传输。

🔬 方法详解

问题定义:现有LLM推理过程中,注意力机制需要频繁访问内存中的历史状态(KV Cache),尤其是在处理长序列时,大量的内存读写操作成为性能瓶颈,限制了推理速度和吞吐量。现有方法通常采用模型压缩、量化等手段,但这些方法可能会牺牲模型精度。

核心思路:SparQ Attention的核心在于选择性地访问历史状态,而不是每次都访问全部历史信息。通过某种策略,只保留对当前token生成贡献最大的历史token的表示,从而减少内存访问量,提高带宽利用率。这种选择性访问的设计目标是在保证模型精度的前提下,尽可能减少数据传输量。

技术框架:SparQ Attention可以作为一个独立的模块插入到现有的Transformer架构的注意力层中。整体流程包括:1)计算Query和Key之间的相似度;2)根据相似度或其他指标,选择一部分Key-Value对;3)使用选择后的Key-Value对进行注意力计算。该框架可以灵活地与不同的选择策略相结合。

关键创新:SparQ Attention的关键创新在于提出了一个通用的框架,允许在注意力计算过程中进行选择性的历史信息访问。与传统注意力机制不同,SparQ Attention不是无差别地使用所有历史信息,而是根据一定的策略进行筛选,从而减少了内存访问量。这种选择性访问的思想可以应用于不同的注意力变体和模型架构。

关键设计:论文中可能涉及的关键设计包括:1)选择策略:如何选择重要的Key-Value对?可以使用基于相似度的阈值方法,或者使用学习到的选择器。2)相似度度量:如何计算Query和Key之间的相似度?可以使用点积、余弦相似度等。3)阈值设置:如果使用基于相似度的阈值方法,如何设置合适的阈值?阈值过高会导致信息丢失,阈值过低则无法有效减少内存访问量。4)模型集成:如何将SparQ Attention集成到现有的LLM中?需要考虑兼容性和性能优化。

📊 实验亮点

论文在Llama 2和3、Mistral、Gemma和Pythia等多个LLM模型上进行了实验,并在各种下游任务上进行了评估。实验结果表明,SparQ Attention在不显著降低准确性的情况下,最多可节省8倍的注意力数据传输。这一结果表明SparQ Attention能够有效地提高LLM的推理效率,具有很强的实用价值。

🎯 应用场景

SparQ Attention具有广泛的应用前景,尤其是在资源受限的场景下,例如移动设备、边缘计算等。它可以显著提高LLM在这些设备上的推理速度,使得LLM能够更好地服务于各种应用,如智能助手、机器翻译、文本生成等。此外,SparQ Attention还可以应用于需要处理长序列的任务,例如文档摘要、对话系统等,提高处理效率。

📄 摘要(原文)

The computational difficulties of large language model (LLM) inference remain a significant obstacle to their widespread deployment. The need for many applications to support long input sequences and process them in large batches typically causes token-generation to be bottlenecked by data transfer. For this reason, we introduce SparQ Attention, a technique for increasing the inference throughput of LLMs by utilising memory bandwidth more efficiently within the attention layers, through selective fetching of the cached history. Our proposed technique can be applied directly to off-the-shelf LLMs during inference, without requiring any modification to the pre-training setup or additional fine-tuning. We show that SparQ Attention brings up to 8x savings in attention data transfers without substantial drops in accuracy, by evaluating Llama 2 and 3, Mistral, Gemma and Pythia models on a wide range of downstream tasks.