LASP-2: Rethinking Sequence Parallelism for Linear Attention and Its Hybrid

📄 arXiv: 2502.07563v1 📥 PDF

作者: Weigao Sun, Disen Lan, Yiran Zhong, Xiaoye Qu, Yu Cheng

分类: cs.LG, cs.AI, cs.CL

发布日期: 2025-02-11

备注: Technical report, 17 pages

🔗 代码/项目: GITHUB


💡 一句话要点

LASP-2:重新设计线性注意力序列并行,提升超长序列训练效率

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 线性注意力 序列并行 分布式训练 超长序列 通信优化 Transformer模型 LASP-2 AllGather

📋 核心要点

  1. 现有序列并行方法在处理线性注意力模型时,未能充分优化其计算特性或通信策略,导致并行效率受限。
  2. LASP-2通过重新设计通信流程,减少了通信量,并提升了计算并行度,从而加速了线性注意力模型的训练。
  3. 实验表明,LASP-2在训练速度上显著优于现有方法,尤其是在超长序列和大规模GPU集群上。

📝 摘要(中文)

线性注意力等线性序列建模方法在训练时具有线性时间复杂度,推理时具有恒定内存占用等优势。然而,现有的序列并行(SP)方法要么没有针对线性注意力的“右积优先”特性进行优化,要么使用环形通信策略,导致计算并行度较低,限制了其在分布式系统中对更长序列的可扩展性。本文提出了LASP-2,一种新的SP方法,旨在提高训练具有超长输入序列的线性注意力Transformer模型时的通信和计算并行度。与之前的LASP相比,LASP-2重新思考了线性注意力层SP的最小通信需求,并重新组织了LASP的整体通信-计算工作流程。通过这种方式,只需要对中间内存状态进行一次AllGather集体通信,其大小与序列长度无关,从而显著提高了通信和计算并行度及其重叠。此外,通过将类似通信重新设计应用于标准注意力模块,我们将LASP-2扩展到LASP-2H,为混合模型(混合线性和标准注意力层)提供了一种有效的SP解决方案。在线性Llama3模型(一种用线性注意力代替标准注意力的Llama3变体)上的评估表明了LASP-2和LASP-2H的有效性。具体而言,在64个GPU上,序列长度为2048K时,LASP-2的训练速度比LASP提高了15.2%,比Ring Attention提高了36.6%。代码已作为https://github.com/OpenSparseLLMs/Linear-MoE 的一部分发布。

🔬 方法详解

问题定义:现有序列并行方法在训练具有超长序列的线性注意力模型时存在瓶颈。具体来说,它们要么没有针对线性注意力的“右积优先”特性进行优化,要么采用环形通信策略,导致计算并行度较低,限制了模型在分布式系统中的可扩展性。这些方法在高吞吐量和低延迟方面存在不足,无法充分利用分布式计算资源。

核心思路:LASP-2的核心思路是重新思考线性注意力层序列并行的最小通信需求,并重新组织通信-计算工作流程。通过减少通信量,并优化计算流程,从而提高整体的训练效率。关键在于利用线性注意力的特性,将通信量与序列长度解耦,从而实现更好的可扩展性。

技术框架:LASP-2的核心在于对线性注意力层的序列并行策略进行优化。它主要包含以下几个阶段:数据划分、局部计算、全局通信和结果聚合。与传统方法不同,LASP-2通过精心设计的通信模式,减少了全局通信的次数和数据量。对于混合模型,LASP-2H则将类似的通信优化策略扩展到标准注意力模块。

关键创新:LASP-2最重要的技术创新点在于其通信策略的重新设计。它将中间内存状态的AllGather通信量与序列长度解耦,使得通信量不再随序列长度线性增长。这与传统的序列并行方法形成了本质区别,后者通常需要进行多次通信,且通信量与序列长度成正比。

关键设计:LASP-2的关键设计在于如何最小化通信量,同时保证计算的正确性。这涉及到对线性注意力计算过程的深入理解,以及对分布式通信原语的巧妙运用。具体的技术细节包括:选择合适的AllGather策略、优化数据布局、以及设计高效的计算kernel。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,LASP-2在训练Linear-Llama3模型时,相比于LASP和Ring Attention,在训练速度上分别提升了15.2%和36.6%(在64个GPU上,序列长度为2048K)。这表明LASP-2在处理超长序列时具有显著的优势。

🎯 应用场景

LASP-2及其扩展LASP-2H可应用于训练具有超长序列的线性注意力Transformer模型,例如用于处理超长文本、基因组序列或时间序列数据。该方法能够显著提高训练效率,降低计算成本,从而加速相关领域的研究和应用。

📄 摘要(原文)

Linear sequence modeling approaches, such as linear attention, provide advantages like linear-time training and constant-memory inference over sequence lengths. However, existing sequence parallelism (SP) methods are either not optimized for the right-product-first feature of linear attention or use a ring-style communication strategy, which results in lower computation parallelism, limits their scalability for longer sequences in distributed systems. In this paper, we introduce LASP-2, a new SP method to enhance both communication and computation parallelism when training linear attention transformer models with very-long input sequences. Compared to previous work LASP, LASP-2 rethinks the minimal communication requirement for SP on linear attention layers, reorganizes the whole communication-computation workflow of LASP. In this way, only one single AllGather collective communication is needed on intermediate memory states, whose sizes are independent of the sequence length, leading to significant improvements of both communication and computation parallelism, as well as their overlap. Additionally, we extend LASP-2 to LASP-2H by applying similar communication redesign to standard attention modules, offering an efficient SP solution for hybrid models that blend linear and standard attention layers. Our evaluation on a Linear-Llama3 model, a variant of Llama3 with linear attention replacing standard attention, demonstrates the effectiveness of LASP-2 and LASP-2H. Specifically, LASP-2 achieves training speed improvements of 15.2% over LASP and 36.6% over Ring Attention, with a sequence length of 2048K across 64 GPUs. The Code is released as a part of: https://github.com/OpenSparseLLMs/Linear-MoE.