Block Sparse Flash Attention

📄 arXiv: 2512.07011v1 📥 PDF

作者: Daniel Ohayon, Itay Lamprecht, Itay Hubara, Israel Cohen, Daniel Soudry, Noam Elata

分类: cs.LG, cs.CL, cs.PF

发布日期: 2025-12-07

备注: 10 pages, 5 figures. Code: https://github.com/Danielohayon/Block-Sparse-Flash-Attention

🔗 代码/项目: GITHUB


💡 一句话要点

提出块稀疏Flash Attention加速长文本推理,保持模型质量。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 长文本处理 稀疏Attention FlashAttention 语言模型加速 推理优化

📋 核心要点

  1. Attention机制在处理长文本时面临二次复杂度挑战,成为大型语言模型的计算瓶颈。
  2. BSFA通过计算精确的query-key相似度,选择最重要的value块,从而实现稀疏计算。
  3. 实验表明,BSFA在加速推理的同时,保持甚至提升了模型准确率,优于现有稀疏Attention方法。

📝 摘要(中文)

现代大型语言模型越来越多地需要长上下文来进行推理和多文档任务,但Attention机制的二次复杂度造成了严重的计算瓶颈。我们提出了块稀疏FlashAttention(BSFA),这是一种即插即用的替代方案,可以在保持模型质量的同时加速长上下文推理。与在计算得分之前预测重要性的方法不同,BSFA计算精确的query-key相似度,以选择对于每个query最重要的top-k个value块。通过将每个块的最大得分与校准的阈值进行比较,我们跳过了大约50%的被剪枝块的计算和内存传输。我们的免训练方法只需要在小数据集上进行一次阈值校准,以学习每层和每头的attention得分分布。我们提供了一个CUDA内核实现,可以作为FlashAttention的即插即用替代品。在Llama-3.1-8B上,BSFA在真实世界的推理基准测试中实现了高达1.10倍的加速,在needle-in-a-haystack检索任务中实现了高达1.24倍的加速,同时保持了99%以上的基线准确率,某些配置甚至通过专注于最相关的内容来提高准确率,大大优于现有的稀疏attention方法。该实现可在https://github.com/Danielohayon/Block-Sparse-Flash-Attention获得。

🔬 方法详解

问题定义:论文旨在解决长文本场景下,标准Attention机制计算复杂度过高的问题。现有方法要么牺牲模型精度来预测重要性,要么计算效率提升有限,无法满足实际应用需求。

核心思路:核心思想是利用块稀疏性,只关注对每个query最重要的value块。通过精确计算query-key相似度,选择top-k个value块进行计算,从而减少计算量和内存传输。这种方法避免了预先预测重要性带来的精度损失。

技术框架:BSFA可以作为FlashAttention的即插即用替代品。整体流程包括:1) 计算query-key相似度;2) 对每个query,选择top-k个value块;3) 根据校准的阈值,跳过不重要的块的计算;4) 执行FlashAttention计算。关键在于阈值校准和块选择策略。

关键创新:最重要的创新在于,它不是预先预测重要性,而是通过计算精确的query-key相似度来动态选择重要的value块。此外,通过校准阈值,可以自适应地跳过不重要的块,进一步提高计算效率。与现有稀疏Attention方法相比,BSFA在精度和效率上都取得了更好的平衡。

关键设计:阈值校准是关键设计之一,通过在小数据集上学习每层和每头的attention得分分布,确定合适的阈值。另一个关键设计是块大小的选择,需要在计算效率和模型精度之间进行权衡。论文提供了一个CUDA内核实现,优化了内存访问和计算效率。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

在Llama-3.1-8B模型上,BSFA在真实世界的推理基准测试中实现了高达1.10倍的加速,在needle-in-a-haystack检索任务中实现了高达1.24倍的加速,同时保持了99%以上的基线准确率。某些配置甚至通过专注于最相关的内容来提高准确率,显著优于现有的稀疏attention方法。

🎯 应用场景

BSFA可广泛应用于需要处理长文本的场景,如文档总结、机器翻译、问答系统、代码生成等。通过降低计算成本,BSFA使得大型语言模型能够更高效地处理更长的上下文,从而提升模型在复杂任务中的表现。该技术还有助于降低部署成本,使更多用户能够使用高性能的语言模型。

📄 摘要(原文)

Modern large language models increasingly require long contexts for reasoning and multi-document tasks, but attention's quadratic complexity creates a severe computational bottleneck. We present Block-Sparse FlashAttention (BSFA), a drop-in replacement that accelerates long-context inference while preserving model quality. Unlike methods that predict importance before computing scores, BSFA computes exact query-key similarities to select the top-k most important value blocks for each query. By comparing per-block maximum scores against calibrated thresholds, we skip approximately 50% of the computation and memory transfers for pruned blocks. Our training-free approach requires only a one-time threshold calibration on a small dataset to learn the per-layer and per-head attention score distributions. We provide a CUDA kernel implementation that can be used as a drop-in replacement for FlashAttention. On Llama-3.1-8B, BSFA achieves up to 1.10x speedup on real-world reasoning benchmarks and up to 1.24x for needle-in-a-haystack retrieval tasks while maintaining above 99% baseline accuracy, with certain configurations even improving accuracy by focusing on the most relevant content, substantially outperforming existing sparse attention methods. The implementation is available at https://github.com/Danielohayon/Block-Sparse-Flash-Attention