Star Attention: Efficient LLM Inference over Long Sequences

📄 arXiv: 2411.17116v3 📥 PDF

作者: Shantanu Acharya, Fei Jia, Boris Ginsburg

分类: cs.CL, cs.AI, cs.LG

发布日期: 2024-11-26 (更新: 2025-05-30)

备注: Accepted at ICML 2025


💡 一句话要点

提出Star Attention,通过块稀疏注意力加速长序列LLM推理。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 长序列建模 注意力机制 大型语言模型 推理加速 块稀疏注意力

📋 核心要点

  1. 长序列LLM推理面临自注意力机制带来的二次复杂度挑战,导致计算成本高昂和速度慢。
  2. Star Attention采用两阶段块稀疏注意力,第一阶段并行局部注意力,第二阶段全局注意力,降低计算复杂度。
  3. 实验表明,Star Attention能显著减少内存需求和推理时间,最高达11倍,同时保持较高的准确率(97-100%)。

📝 摘要(中文)

基于Transformer的大型语言模型(LLM)在长序列上的推理由于自注意力机制的二次复杂度而成本高昂且速度缓慢。我们提出了Star Attention,一种两阶段块稀疏近似方法,通过在多个主机上分片注意力来提高计算效率,同时最大限度地减少通信开销。在第一阶段,上下文使用跨主机的块状局部注意力并行处理。在第二阶段,查询和响应token通过序列全局注意力关注所有先前缓存的token。Star Attention与大多数使用全局注意力训练的基于Transformer的LLM无缝集成,将内存需求和推理时间最多减少11倍,同时保持97-100%的准确率。

🔬 方法详解

问题定义:现有的大型语言模型在处理长序列时,自注意力机制的计算复杂度呈二次方增长,导致推理速度慢,内存消耗大,难以部署到资源受限的设备上。因此,如何降低长序列LLM推理的计算成本和内存需求是一个关键问题。

核心思路:Star Attention的核心思路是通过块稀疏注意力来近似全局注意力,从而降低计算复杂度。具体来说,它将注意力计算分为两个阶段:第一阶段是块状局部注意力,在每个主机上并行计算;第二阶段是序列全局注意力,将查询和响应token与所有先前缓存的token进行交互。

技术框架:Star Attention的整体框架包含两个主要阶段:1) 块状局部注意力阶段:输入序列被分成多个块,每个块由一个主机处理。在每个主机上,token只关注其所在块内的其他token,实现局部注意力。这个阶段可以并行执行,提高计算效率。2) 序列全局注意力阶段:查询和响应token需要访问整个序列的信息。因此,在这个阶段,查询和响应token会关注所有先前缓存的token,实现全局注意力。

关键创新:Star Attention的关键创新在于其两阶段的块稀疏注意力机制。与传统的全局注意力相比,它显著降低了计算复杂度,同时通过第二阶段的全局注意力,尽可能地保留了全局信息。与其他的稀疏注意力方法相比,Star Attention在分布式环境下具有更好的通信效率。

关键设计:Star Attention的关键设计包括:1) 块大小的选择:块大小的选择会影响局部注意力的计算量和全局信息的保留程度。需要根据具体的应用场景进行调整。2) 全局注意力范围的确定:全局注意力只需要关注查询和响应token,可以显著减少计算量。3) 与现有Transformer模型的兼容性:Star Attention可以无缝集成到大多数基于Transformer的LLM中,无需重新训练模型。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

Star Attention在长序列推理任务上表现出色,与全局注意力相比,内存需求和推理时间最多减少11倍,同时保持了97-100%的准确率。这表明Star Attention能够在显著降低计算成本的同时,保持较高的模型性能。该方法在多个基准测试上都取得了优异的结果。

🎯 应用场景

Star Attention适用于需要处理长序列的各种应用场景,例如长文本摘要、机器翻译、对话生成、代码生成等。它可以降低LLM的推理成本,使其能够部署到资源受限的设备上,并加速LLM的应用落地。此外,该方法还可以用于训练更长的序列,从而提高LLM的性能。

📄 摘要(原文)

Inference with Transformer-based Large Language Models (LLMs) on long sequences is both costly and slow due to the quadratic complexity of the self-attention mechanism. We introduce Star Attention, a two-phase block-sparse approximation that improves computational efficiency by sharding attention across multiple hosts while minimizing communication overhead. In the first phase, the context is processed using blockwise-local attention across hosts, in parallel. In the second phase, query and response tokens attend to all prior cached tokens through sequence-global attention. Star Attention integrates seamlessly with most Transformer-based LLMs trained with global attention, reducing memory requirements and inference time by up to 11x while preserving 97-100% of accuracy.