Mixture of Sparse Attention: Content-Based Learnable Sparse Attention via Expert-Choice Routing

📄 arXiv: 2505.00315v1 📥 PDF

作者: Piotr Piękos, Róbert Csordás, Jürgen Schmidhuber

分类: cs.LG, cs.CL

发布日期: 2025-05-01


💡 一句话要点

提出MoSA:通过专家选择路由实现内容感知的可学习稀疏注意力机制,提升计算效率。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 稀疏注意力 混合专家模型 专家选择路由 长序列建模 计算效率

📋 核心要点

  1. 自注意力机制的计算复杂度是序列长度的平方,限制了其在长序列上的应用,现有稀疏注意力方法性能仍有不足。
  2. MoSA通过专家选择路由,为每个注意力头动态选择token,实现内容感知的稀疏注意力模式,降低计算复杂度。
  3. 实验表明,MoSA在相同计算预算下,困惑度优于密集基线,且训练速度更快,内存占用更少,KV缓存更小。

📝 摘要(中文)

大型语言模型的最新进展凸显了自注意力机制的二次方计算成本过高。尽管研究投入巨大,但亚二次方注意力方法在实践中仍然表现不佳。我们假设动态的、可学习的、基于内容的稀疏性可以带来更高效的注意力机制。我们提出了混合稀疏注意力(MoSA),这是一种受到混合专家(MoE)和专家选择路由启发的创新方法。MoSA为每个注意力头动态选择token,从而允许任意的稀疏注意力模式。通过从长度为T的序列中选择k个token,MoSA将每个注意力头的计算复杂度从O(T^2)降低到O(k^2 + T)。这使得在相同的计算预算内可以使用更多的头,从而实现更高的专业化。我们表明,在测试的稀疏注意力变体中,MoSA是唯一能够超越密集基线的方法,有时在相同的计算预算下,困惑度最多可提高27%。MoSA还可以减少与密集自注意力相比的资源使用。尽管使用了没有优化内核的torch实现,但困惑度匹配的MoSA模型在挂钟时间上更快,训练所需的内存更少,并且与密集transformer基线相比,KV缓存的大小大大减小。

🔬 方法详解

问题定义:论文旨在解决自注意力机制在处理长序列时计算复杂度过高的问题。现有稀疏注意力方法虽然降低了计算复杂度,但在实际应用中性能往往不如密集注意力机制,无法充分利用计算资源。

核心思路:论文的核心思路是引入内容感知的可学习稀疏性。通过模仿混合专家模型(MoE)的专家选择路由机制,MoSA允许每个注意力头动态地选择序列中的一部分token进行关注,从而打破了传统稀疏注意力机制中预定义的稀疏模式,使得模型能够根据输入内容自适应地学习更有效的稀疏模式。

技术框架:MoSA的整体框架是在Transformer架构的基础上,将传统的自注意力层替换为MoSA层。MoSA层包含以下几个关键步骤:1) 路由选择:使用一个路由网络(通常是一个小型神经网络)为每个注意力头选择最相关的k个token。路由网络的输入是query向量和所有token的key向量,输出是每个token被选择的概率或得分。2) 稀疏注意力计算:每个注意力头只关注被选中的k个token,计算复杂度从O(T^2)降低到O(k^2)。3) 结果聚合:将每个注意力头的输出进行聚合,得到最终的注意力输出。

关键创新:MoSA的关键创新在于其动态的、内容感知的稀疏模式。与预定义的稀疏模式相比,MoSA能够根据输入内容自适应地学习更有效的稀疏模式,从而在降低计算复杂度的同时,保持甚至提升模型的性能。此外,MoSA通过专家选择路由机制,使得每个注意力头可以专注于不同的token子集,从而实现更高的专业化。

关键设计:MoSA的关键设计包括:1) 路由网络的选择:可以使用不同的神经网络作为路由网络,例如简单的线性层或更复杂的MLP。2) 选择token的数量k:k的取值会影响模型的计算复杂度和性能,需要根据具体的任务和数据集进行调整。3) 路由损失:为了鼓励不同的注意力头关注不同的token子集,可以引入路由损失,例如KL散度损失或互信息损失。4) Top-k选择:路由网络输出每个token的得分后,通常使用Top-k选择算法来选择得分最高的k个token。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,MoSA在语言建模任务上优于传统的密集自注意力机制。在相同的计算预算下,MoSA可以将困惑度降低高达27%。此外,MoSA模型在训练速度上更快,内存占用更少,并且KV缓存的大小大大减小。即使使用未优化的torch实现,MoSA模型在挂钟时间上仍然优于密集基线。

🎯 应用场景

MoSA适用于需要处理长序列的各种自然语言处理任务,例如机器翻译、文本摘要、问答系统和语言建模。其降低计算复杂度和内存占用的特性,使其在资源受限的设备上部署大型语言模型成为可能。未来,MoSA可以扩展到其他领域,例如计算机视觉和语音识别,以处理高分辨率图像和长音频序列。

📄 摘要(原文)

Recent advances in large language models highlighted the excessive quadratic cost of self-attention. Despite the significant research efforts, subquadratic attention methods still suffer from inferior performance in practice. We hypothesize that dynamic, learned content-based sparsity can lead to more efficient attention mechanisms. We present Mixture of Sparse Attention (MoSA), a novel approach inspired by Mixture of Experts (MoE) with expert choice routing. MoSA dynamically selects tokens for each attention head, allowing arbitrary sparse attention patterns. By selecting $k$ tokens from a sequence of length $T$, MoSA reduces the computational complexity of each attention head from $O(T^2)$ to $O(k^2 + T)$. This enables using more heads within the same computational budget, allowing higher specialization. We show that among the tested sparse attention variants, MoSA is the only one that can outperform the dense baseline, sometimes with up to 27% better perplexity for an identical compute budget. MoSA can also reduce the resource usage compared to dense self-attention. Despite using torch implementation without an optimized kernel, perplexity-matched MoSA models are simultaneously faster in wall-clock time, require less memory for training, and drastically reduce the size of the KV-cache compared to the dense transformer baselines.