Gated Linear Attention Transformers with Hardware-Efficient Training
作者: Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, Yoon Kim
分类: cs.LG, cs.CL
发布日期: 2023-12-11 (更新: 2024-08-27)
备注: minor update
💡 一句话要点
提出硬件高效的门控线性注意力Transformer,提升训练速度和长序列泛化能力。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 线性注意力 Transformer 硬件高效 门控机制 长序列建模 语言建模 FLASHLINEARATTENTION
📋 核心要点
- 线性注意力虽然推理速度快,但性能通常不如softmax注意力,且现有实现缺乏I/O优化。
- 提出FLASHLINEARATTENTION算法,通过权衡内存移动和并行性,实现硬件高效的线性注意力。
- 引入门控机制,提出GLA Transformer,在语言建模和长序列泛化方面表现出竞争力,且训练速度优于Mamba。
📝 摘要(中文)
线性注意力Transformer具有高效的并行训练能力,并且可以被形式化为具有二维(矩阵值)隐藏状态的RNN,从而实现线性时间复杂度的推理。然而,线性注意力通常不如普通的softmax注意力。此外,当前线性注意力的实现缺乏I/O感知,因此比高度优化的softmax注意力实现慢。本文提出了一种硬件高效的线性注意力算法,它在内存移动和并行性之间进行权衡。由此产生的实现,被称为FLASHLINEARATTENTION,即使在短序列长度(例如,1K)上,也比FLASHATTENTION-2更快。然后,我们将此算法推广到具有数据相关门控的更具表现力的线性注意力变体。当用作Transformer中标准注意力层的替代品时,发现由此产生的门控线性注意力(GLA)Transformer在适度规模的语言建模实验中,与LLaMA架构的Transformer以及最近的线性时间推理基线(如RetNet和Mamba)相比,具有竞争力。GLA Transformer在长度泛化方面尤其有效,使在2K上训练的模型能够泛化到超过20K的序列,而不会出现明显的困惑度下降。在训练速度方面,GLA Transformer比类似大小的Mamba模型具有更高的吞吐量。
🔬 方法详解
问题定义:线性注意力Transformer虽然具有线性时间复杂度的推理优势,但其性能通常低于softmax注意力,并且现有实现没有充分考虑硬件I/O效率,导致实际运行速度较慢。因此,需要一种既能保持线性推理速度,又能提升性能和硬件效率的线性注意力机制。
核心思路:核心思路是通过优化内存访问模式和并行计算策略,设计一种硬件友好的线性注意力算法。此外,引入门控机制,增强模型的表达能力,从而提升整体性能。
技术框架:该研究主要包含两个部分:一是FLASHLINEARATTENTION算法的实现,旨在优化线性注意力的硬件效率;二是将FLASHLINEARATTENTION推广到门控线性注意力(GLA),并将其集成到Transformer架构中。整体流程包括:线性注意力计算优化 -> 门控机制引入 -> Transformer集成 -> 实验验证。
关键创新:关键创新在于FLASHLINEARATTENTION算法,它通过权衡内存移动和并行性,实现了比现有线性注意力实现更高的硬件效率。此外,门控机制的引入增强了模型的表达能力,使其在语言建模任务中表现出竞争力。
关键设计:FLASHLINEARATTENTION算法的关键设计在于优化了内存访问模式,减少了不必要的内存移动,并充分利用了硬件的并行计算能力。门控机制的具体实现细节(例如门控函数的选择、门控值的计算方式等)在论文中可能有所描述,但摘要中未明确提及。损失函数和网络结构方面,GLA Transformer主要沿用了Transformer的通用设计,关键在于将标准注意力层替换为GLA层。
📊 实验亮点
实验结果表明,FLASHLINEARATTENTION算法在短序列长度上已经优于FLASHATTENTION-2。GLA Transformer在语言建模任务中表现出与LLaMA和RetNet等模型相当的性能,并且在长序列泛化方面表现出色,能够在2K长度上训练的模型泛化到20K以上的序列,且训练吞吐量高于Mamba。
🎯 应用场景
该研究成果可应用于各种需要处理长序列数据的场景,例如自然语言处理、语音识别、视频分析等。尤其是在资源受限的设备上,硬件高效的线性注意力机制能够提供更快的推理速度和更低的能耗。未来,该技术有望推动大规模语言模型的部署和应用。
📄 摘要(原文)
Transformers with linear attention allow for efficient parallel training but can simultaneously be formulated as an RNN with 2D (matrix-valued) hidden states, thus enjoying linear-time inference complexity. However, linear attention generally underperforms ordinary softmax attention. Moreover, current implementations of linear attention lack I/O-awareness and are thus slower than highly optimized implementations of softmax attention. This work describes a hardware-efficient algorithm for linear attention that trades off memory movement against parallelizability. The resulting implementation, dubbed FLASHLINEARATTENTION, is faster than FLASHATTENTION-2 (Dao, 2023) as a standalone layer even on short sequence lengths (e.g., 1K). We then generalize this algorithm to a more expressive variant of linear attention with data-dependent gates. When used as a replacement for the standard attention layer in Transformers, the resulting gated linear attention (GLA) Transformer is found to perform competitively against the LLaMA-architecture Transformer (Touvron et al., 2023) as well recent linear-time-inference baselines such as RetNet (Sun et al., 2023a) and Mamba (Gu & Dao, 2023) on moderate-scale language modeling experiments. GLA Transformer is especially effective at length generalization, enabling a model trained on 2K to generalize to sequences longer than 20K without significant perplexity degradations. For training speed, the GLA Transformer has higher throughput than a similarly-sized Mamba model.