DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention

📄 arXiv: 2605.18753v1 📥 PDF

作者: Yuxiang Huang, Nuno M. T. Gonçalves, Federico Alvetreti, Lei Li, Xu Han, Edoardo M. Ponti, André F. T. Martins, Marcos V. Treviso

分类: cs.CL, cs.AI, cs.LG

发布日期: 2026-05-18

备注: Preprint


💡 一句话要点

提出DashAttention,一种可微自适应稀疏分层注意力机制,提升长文本建模效率。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 分层注意力 稀疏注意力 长文本建模 自适应稀疏性 α-entmax 大型语言模型 GPU加速

📋 核心要点

  1. 现有分层注意力方法的top-k选择策略限制了模型对不同查询自适应选择相关token的能力,并阻断了梯度流动。
  2. DashAttention利用可微的自适应稀疏$α$-entmax变换,动态选择相关KV块,并为后续softmax注意力提供先验信息。
  3. 实验表明,DashAttention在保持竞争力的同时,在高稀疏度下优于现有方法,并实现了优于FlashAttention-3的推理速度。

📝 摘要(中文)

本文提出了一种可微自适应稀疏分层注意力机制DashAttention,旨在解决现有分层注意力方法(如NSA和InfLLMv2)中存在的top-k选择策略的局限性。现有方法基于粗略注意力分数选择top-k个相关的键值(KV)块,然后对选定的token应用细粒度的softmax注意力。然而,top-k操作假定任何查询的相关token数量是固定的,并且阻断了稀疏和密集阶段之间的梯度流动。DashAttention利用自适应稀疏的$α$-entmax变换,根据当前查询自适应地选择可变数量的块,并在第一阶段提供先验信息,用于第二阶段的softmax注意力,从而保持整个层次结构完全可微。实验表明,DashAttention具有非分散性,从而提高了长上下文建模能力。在大型语言模型(LLM)上的实验表明,DashAttention在75%的稀疏度下实现了与完整注意力相当的精度,并且比NSA和InfLLMv2具有更好的Pareto前沿,尤其是在高稀疏度情况下。此外,本文还提供了一个高效的、GPU感知的Triton实现,在推理时实现了比FlashAttention-3更高的加速。总而言之,DashAttention提供了一种经济高效的策略来建模长上下文。

🔬 方法详解

问题定义:现有分层注意力方法,如NSA和InfLLMv2,采用top-k选择策略,即基于粗略注意力分数选择固定数量的KV块。这种方法的痛点在于:一是假设所有查询都需要相同数量的token,缺乏灵活性;二是top-k操作不可微,导致梯度无法在稀疏和密集阶段之间流动,影响模型训练效果。

核心思路:DashAttention的核心思路是引入可微的自适应稀疏性。具体而言,使用$α$-entmax变换替代传统的top-k选择,使得模型能够根据当前查询动态地选择不同数量的KV块。这种自适应选择机制允许模型更好地捕捉不同查询的需求,同时保持整个过程的可微性,从而实现端到端的优化。

技术框架:DashAttention是一个两阶段的分层注意力机制。第一阶段,使用$α$-entmax变换对KV块进行稀疏化选择,得到一个稀疏的注意力权重分布。第二阶段,基于第一阶段的稀疏注意力权重,对选定的KV块进行细粒度的softmax注意力计算。整个框架是端到端可微的,允许梯度在两个阶段之间自由流动。

关键创新:DashAttention的关键创新在于使用$α$-entmax变换实现自适应稀疏注意力。与传统的top-k选择相比,$α$-entmax变换具有以下优势:一是可微性,允许梯度在整个网络中传播;二是自适应性,能够根据当前查询动态地选择不同数量的KV块;三是非分散性,有助于提高长上下文建模能力。

关键设计:$α$-entmax变换是DashAttention的关键组成部分。$α$是一个超参数,控制稀疏度。当$α=1$时,$α$-entmax退化为softmax;当$α=2$时,$α$-entmax产生一个稀疏的概率分布,只有少数几个元素具有非零值。论文中使用了可学习的$α$值,允许模型根据数据自适应地调整稀疏度。此外,论文还提供了一个高效的GPU-aware Triton实现,以加速DashAttention的计算。

📊 实验亮点

实验结果表明,DashAttention在75%的稀疏度下实现了与完整注意力相当的精度,并且比NSA和InfLLMv2具有更好的Pareto前沿,尤其是在高稀疏度情况下。此外,DashAttention的Triton实现实现了比FlashAttention-3更高的推理速度,证明了其高效性。

🎯 应用场景

DashAttention适用于需要处理长序列数据的各种应用场景,例如长文本摘要、机器翻译、语音识别、视频理解等。通过降低计算复杂度,DashAttention使得在资源受限的环境中部署大型语言模型成为可能,并有望推动长上下文建模技术的发展。

📄 摘要(原文)

Current hierarchical attention methods, such as NSA and InfLLMv2, select the top-k relevant key-value (KV) blocks based on coarse attention scores and subsequently apply fine-grained softmax attention on the selected tokens. However, the top-k operation assumes the number of relevant tokens for any query is fixed and it precludes the gradient flow between the sparse and dense stages. In this work, we propose DashAttention (Differentiable and Adaptive Sparse Hierarchical Attention), which leverages the adaptively sparse $α$-entmax transformation to select a variable number of blocks according to the current query in the first stage. This in turn provides a prior for the second-stage softmax attention, keeping the entire hierarchy fully differentiable. Contrary to other hierarchical attention methods, we show that DashAttention is non-dispersive, translating to better long-context modeling ability. Experiments with large language models (LLMs) show that DashAttention achieves comparable accuracy as full attention with 75% sparsity and a better Pareto frontier than NSA and InfLLMv2, especially in high-sparsity regimes. We also provide an efficient, GPU-aware implementation of DashAttention in Triton, which achieves a speedup of up to over FlashAttention-3 at inference time. Overall, DashAttention offers a cost-effective strategy to model long contexts.