Linear Attention Sequence Parallelism

📄 arXiv: 2404.02882v3 📥 PDF

作者: Weigao Sun, Zhen Qin, Dong Li, Xuyang Shen, Yu Qiao, Yiran Zhong

分类: cs.LG, cs.CL

发布日期: 2024-04-03 (更新: 2025-05-16)

备注: Accepted by TMLR, 23 pages

🔗 代码/项目: GITHUB


💡 一句话要点

提出线性注意力序列并行方法以提升长序列处理效率

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 线性注意力 序列并行 计算效率 分布式训练 长序列处理

📋 核心要点

  1. 现有的序列并行方法未能充分利用线性注意力的右乘优先特性,导致通信效率低下。
  2. 本文提出线性注意力序列并行(LASP),通过设计高效的通信机制和优化计算流程来提升性能。
  3. 实验结果表明,LASP在128个GPU上支持的序列长度达到4096K,较现有方法提升了8倍。

📝 摘要(中文)

序列并行(SP)是处理超出单个设备内存限制的长序列的常用策略。然而,对于线性序列建模方法如线性注意力,现有的SP方法未能利用其右乘优先特性,导致通信效率和可用性低下。本文提出了线性注意力序列并行(LASP),旨在为基于线性注意力的变换器模型设计一种高效的SP方法。我们设计了一种高效的点对点环形通信机制,以利用线性注意力的右乘核技巧,显著降低通信开销。通过核融合和中间状态缓存,我们提高了LASP的计算效率,使其在GPU上实现硬件友好。此外,我们确保了序列级LASP与所有类型的批量级数据并行方法的兼容性,这对于在大型集群上进行长序列的分布式训练至关重要。我们还讨论了LASP在其他线性序列建模方法上的推广。对线性注意力模型进行了广泛实验,序列长度从2K到4096K不等。LASP在128个GPU上将序列长度扩展至4096K,是现有SP方法的8倍。

🔬 方法详解

问题定义:本文旨在解决现有序列并行方法在处理线性注意力模型时的通信效率低下问题,尤其是在长序列情况下的内存限制和计算瓶颈。

核心思路:通过引入线性注意力的右乘优先特性,设计高效的点对点环形通信机制,减少通信开销,并结合核融合和中间状态缓存来提升计算效率。

技术框架:LASP的整体架构包括数据分割、通信机制、计算优化和结果合并等模块。首先将输入序列分割为多个子序列,然后通过环形通信机制进行数据传输,最后在各个GPU上并行计算并合并结果。

关键创新:LASP的主要创新在于利用线性注意力的右乘核技巧,设计了高效的通信机制,显著降低了通信开销,与传统的SP方法相比,提升了整体性能和可用性。

关键设计:在实现LASP时,采用了核融合技术以减少计算冗余,并通过中间状态缓存来优化内存使用。此外,确保了与批量级数据并行方法的兼容性,以支持大规模分布式训练。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果显示,LASP在128个GPU上实现了4096K的序列长度处理,较现有序列并行方法提升了8倍,显著提高了长序列的处理能力和效率。

🎯 应用场景

该研究的潜在应用领域包括自然语言处理、计算机视觉和其他需要处理长序列数据的任务。LASP的高效性使其在大规模分布式训练中具有重要价值,能够支持更长的序列处理,推动相关领域的研究进展。

📄 摘要(原文)

Sequence parallelism (SP) serves as a prevalent strategy to handle long sequences that exceed the memory limit of a single device. However, for linear sequence modeling methods like linear attention, existing SP approaches do not take advantage of their right-product-first feature, resulting in sub-optimal communication efficiency and usability. In this paper, we introduce Linear Attention Sequence Parallelism (LASP), an efficient SP approach designed for linear attention-based transformer models. Specifically, we design an efficient point-to-point ring-style communication mechanism to leverage the right-product kernel trick of linear attention, which sharply decreases the communication overhead, comparing with existing SP methods. We enhance the computation efficiency of LASP by performing kernel fusion and intermediate state caching, making the implementation of LASP hardware-friendly on GPUs. Furthermore, we meticulously ensure the compatibility of sequence-level LASP with all types of batch-level data parallel methods, which is vital for distributed training on large clusters with very-long sequences. We also discuss the generalization of LASP on other linear sequence modeling methods. Extensive experiments on linear attention-based models are conducted with varying sequence lengths from 2K to 4096K. LASP scales sequence length up to 4096K on 128 GPUs, which is 8$\times$ longer than existing SP methods. Code is available at: https://github.com/OpenNLPLab/LASP.