DASH: Deterministic Attention Scheduling for High-throughput Reproducible LLM Training

📄 arXiv: 2601.21824v1 📥 PDF

作者: Xinwei Qiang, Hongmin Chen, Shixuan Sun, Jingwen Leng, Xin Liu, Minyi Guo

分类: cs.LG, cs.DC

发布日期: 2026-01-29

🔗 代码/项目: GITHUB


💡 一句话要点

DASH:用于高吞吐可复现LLM训练的确定性注意力调度

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 确定性注意力 LLM训练 调度算法 反向传播优化 有向无环图

📋 核心要点

  1. 现有FlashAttention等注意力机制的确定性反向传播因梯度累积串行化导致显著的性能下降。
  2. DASH将确定性注意力反向传播建模为DAG调度问题,通过优化调度策略最小化关键路径长度。
  3. 实验表明,DASH能显著提升确定性注意力反向传播的吞吐量,最高可达1.28倍。

📝 摘要(中文)

确定性对于大型语言模型(LLM)训练的可复现性至关重要,但通常会带来巨大的性能代价。在广泛使用的注意力机制实现(如FlashAttention-3)中,确定性反向传播可能导致高达37.9%的吞吐量下降,这主要是因为梯度累积操作必须串行化以保证数值一致性。这种性能损失源于计算和梯度归约阶段的次优调度,导致硬件利用率显著降低。为了解决这个问题,我们将确定性注意力的反向传播建模为有向无环图(DAG)上的调度问题,并推导出最小化关键路径长度的调度方案。在此基础上,我们提出了DASH(用于高吞吐量的确定性注意力调度),它包含两种互补的调度策略:(i)降序Q-Tile迭代,一种反向查询块遍历,可减少因果注意力中的流水线停顿;(ii)移位调度,一种在我们DAG模型中理论上最优的调度,可减少完整和因果掩码的流水线停顿。在NVIDIA H800 GPU上的实验评估表明,DASH缩小了确定性注意力的性能差距。与基线相比,所提出的策略将注意力反向传播的吞吐量提高了高达1.28倍,显著提高了可复现LLM训练的效率。

🔬 方法详解

问题定义:论文旨在解决大型语言模型训练中,为了保证可复现性而采用确定性注意力机制时,反向传播过程性能显著下降的问题。现有方法,如FlashAttention-3的确定性版本,由于需要串行化梯度累积操作以保证数值一致性,导致硬件利用率不足,吞吐量大幅降低。

核心思路:论文的核心思路是将确定性注意力的反向传播过程建模为一个有向无环图(DAG)上的调度问题,通过寻找最优的调度策略来最小化关键路径长度,从而提高硬件利用率和吞吐量。核心在于找到一种既能保证确定性,又能最大程度并行化计算和梯度归约操作的调度方案。

技术框架:DASH包含两个主要的调度策略:降序Q-Tile迭代和移位调度。降序Q-Tile迭代针对因果注意力,通过反向遍历查询块来减少流水线停顿。移位调度则是一种理论上最优的调度方案,适用于完整和因果掩码,旨在减少流水线停顿。整体框架围绕DAG模型展开,通过分析依赖关系,寻找最优的执行顺序。

关键创新:论文的关键创新在于将确定性注意力的反向传播过程形式化为DAG调度问题,并提出了两种互补的调度策略。与现有方法不同,DASH不是简单地串行化梯度累积操作,而是通过精细的调度来最大程度地并行化计算和梯度归约,从而在保证确定性的前提下提升性能。

关键设计:降序Q-Tile迭代的关键在于反向遍历查询块,这允许更早地启动某些计算,从而减少流水线停顿。移位调度的具体实现细节(例如,如何确定最优的移位量)在论文中可能包含更详细的数学推导和算法描述。DAG模型的构建方式,以及如何根据DAG模型生成调度方案,也是关键的设计细节。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,DASH能够显著缩小确定性注意力机制的性能差距。在NVIDIA H800 GPU上,DASH将注意力反向传播的吞吐量提高了高达1.28倍,与基线方法相比,性能提升显著。这一结果表明DASH在提高可复现LLM训练效率方面具有重要价值。

🎯 应用场景

DASH的潜在应用领域包括需要高精度和可复现性的大型语言模型训练,例如金融、医疗等对结果一致性要求高的领域。通过提高确定性注意力机制的效率,DASH可以降低训练成本,加速模型迭代,并促进这些领域中LLM的更广泛应用。未来,DASH的调度思想可以推广到其他需要确定性计算的深度学习任务中。

📄 摘要(原文)

Determinism is indispensable for reproducibility in large language model (LLM) training, yet it often exacts a steep performance cost. In widely used attention implementations such as FlashAttention-3, the deterministic backward pass can incur up to a 37.9% throughput reduction relative to its non-deterministic counterpart, primarily because gradient accumulation operations must be serialized to guarantee numerical consistency. This performance loss stems from suboptimal scheduling of compute and gradient-reduction phases, leading to significant hardware underutilization. To address this challenge, we formulate the backward pass of deterministic attention as a scheduling problem on a Directed Acyclic Graph (DAG) and derive schedules that minimize the critical path length. Building on this formulation, we present DASH (Deterministic Attention Scheduling for High-Throughput), which encapsulates two complementary scheduling strategies: (i) Descending Q-Tile Iteration, a reversed query-block traversal that shrinks pipeline stalls in causal attention, and (ii) Shift Scheduling, a theoretically optimal schedule within our DAG model that reduces pipeline stalls for both full and causal masks. Our empirical evaluations on NVIDIA H800 GPUs demonstrate that DASH narrows the performance gap of deterministic attention. The proposed strategies improve the throughput of the attention backward pass by up to 1.28$\times$ compared to the baseline, significantly advancing the efficiency of reproducible LLM training. Our code is open-sourced at https://github.com/SJTU-Liquid/deterministic-FA3.