Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels
作者: Maximilian Beck, Korbinian Pöppel, Phillip Lippe, Sepp Hochreiter
分类: cs.LG, cs.AI
发布日期: 2025-03-18 (更新: 2025-12-28)
备注: Accepted at NeurIPS 2025. Code available at: https://github.com/NX-AI/mlstm_kernels
💡 一句话要点
提出Tiled Flash线性注意力(TFLA),加速线性RNN和xLSTM内核,提升长序列建模效率。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 线性RNN 长序列建模 Flash Attention xLSTM 序列并行化 高效计算 内存优化
📋 核心要点
- 现有Flash线性注意力(FLA)受限于chunk size,导致大量中间状态需在GPU内存中物化,降低算术强度并增加内存消耗。
- Tiled Flash线性注意力(TFLA)通过在chunk内引入序列并行化,支持任意大chunk size,提升算术强度,降低内存IO成本。
- 实验表明,基于TFLA的mLSTM内核超越Flash Attention、Linear Attention和Mamba,为长上下文序列建模提供更优方案。
📝 摘要(中文)
本文提出了一种用于线性RNN的新型内核算法——Tiled Flash Linear Attention (TFLA)。线性RNN及其门控机制在语言建模方面表现出与Transformer相当的性能。尽管线性RNN在序列长度上的线性计算复杂度在理论上优于Transformer,但要实际发挥这一优势,需要优化的自定义内核,因为Transformer依赖于高效的Flash Attention内核。TFLA通过在每个chunk内引入额外的序列并行化层级,实现了任意大的chunk size和高算术强度。首先,我们将TFLA应用于具有矩阵记忆的xLSTM,即mLSTM。其次,我们提出了一种具有sigmoid输入门和减少计算量的mLSTM变体,以在相同的语言建模性能下实现更快的内核运行时间。速度基准测试表明,基于TFLA的新mLSTM内核优于高度优化的Flash Attention、Linear Attention和Mamba内核,为高效的长上下文序列建模原语设定了新的state-of-the-art。
🔬 方法详解
问题定义:论文旨在解决线性RNN在长序列建模中,由于内存限制和算术强度不足,无法充分发挥其线性计算复杂度的优势的问题。现有方法如Flash Linear Attention (FLA) 虽然利用了chunkwise并行,但chunk size受限,导致大量中间状态需要存储在GPU内存中,造成高内存消耗和IO成本,尤其是在长上下文预训练中。
核心思路:论文的核心思路是在Flash Linear Attention的基础上,引入额外的序列并行化层级,即在每个chunk内部进行并行计算,从而允许使用任意大的chunk size。通过增大chunk size,可以减少中间状态的物化,提高算术强度,并降低内存IO成本。这种“分块再分块”的策略是TFLA的关键。
技术框架:TFLA的核心在于对输入序列进行分块,然后在每个块内进行进一步的并行化处理。具体流程如下: 1. 将输入序列划分为多个chunk。 2. 在每个chunk内部,再次进行序列并行化,将chunk划分为更小的tile。 3. 在每个tile上进行局部计算。 4. 通过并行规约操作,将tile的计算结果合并为chunk的最终结果。 5. 将所有chunk的结果组合起来,得到整个序列的输出。
关键创新:TFLA最重要的技术创新点在于引入了chunk内部的序列并行化,从而突破了Flash Linear Attention的chunk size限制。这使得TFLA能够充分利用GPU的并行计算能力,提高算术强度,并降低内存IO成本。此外,论文还提出了一种具有sigmoid输入门和减少计算量的mLSTM变体,进一步提升了内核的运行速度。
关键设计:TFLA的关键设计包括: 1. Chunk size的选择:更大的chunk size可以提高算术强度,但也会增加计算复杂度。需要根据GPU的内存容量和计算能力进行权衡。 2. Tile size的选择:tile size的选择也会影响并行计算的效率。需要根据GPU的架构进行优化。 3. 并行规约操作的设计:并行规约操作的效率直接影响TFLA的整体性能。需要选择高效的规约算法。
🖼️ 关键图片
📊 实验亮点
实验结果表明,基于TFLA的新mLSTM内核在速度上显著优于高度优化的Flash Attention、Linear Attention和Mamba内核,为高效的长上下文序列建模原语设定了新的state-of-the-art。具体性能数据未在摘要中给出,但强调了其超越现有方法的优势。
🎯 应用场景
TFLA及其优化的mLSTM内核可广泛应用于长文本建模、语音识别、视频处理等需要处理长序列数据的领域。其高效的计算和低内存占用特性,使其在资源受限的设备上部署大型语言模型成为可能,并加速长上下文预训练。
📄 摘要(原文)
Linear RNNs with gating recently demonstrated competitive performance compared to Transformers in language modeling. Although their linear compute scaling in sequence length offers theoretical runtime advantages over Transformers, realizing these benefits in practice requires optimized custom kernels, as Transformers rely on the highly efficient Flash Attention kernels (Dao, 2024). Leveraging the chunkwise-parallel formulation of linear RNNs, Flash Linear Attention (FLA) (Yang & Zhang, 2024) shows that linear RNN kernels are faster than Flash Attention, by parallelizing over chunks of the input sequence. However, since the chunk size of FLA is limited, many intermediate states must be materialized in GPU memory. This leads to low arithmetic intensity and causes high memory consumption and IO cost, especially for long-context pre-training. In this work, we present Tiled Flash Linear Attention (TFLA), a novel kernel algorithm for linear RNNs, that enables arbitrary large chunk sizes and high arithmetic intensity by introducing an additional level of sequence parallelization within each chunk. First, we apply TFLA to the xLSTM with matrix memory, the mLSTM (Beck et al., 2024). Second, we propose an mLSTM variant with sigmoid input gate and reduced computation for even faster kernel runtimes at equal language modeling performance. In our speed benchmarks, we show that our new mLSTM kernels based on TFLA outperform highly optimized Flash Attention, Linear Attention and Mamba kernels, setting a new state of the art for efficient long-context sequence modeling primitives.