Fixed-Point RNNs: Interpolating from Diagonal to Dense
作者: Sajad Movahedi, Felix Sarnthein, Nicola Muca Cirone, Antonio Orvieto
分类: cs.LG
发布日期: 2025-03-13 (更新: 2025-10-24)
备注: NeurIPS 2025 (Spotlight)
💡 一句话要点
提出基于定点RNN的序列建模方法,在效率和表达性之间取得平衡
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 循环神经网络 状态空间模型 序列建模 不动点迭代 线性RNN
📋 核心要点
- 现有线性RNN和SSM模型依赖于对角序列混合,限制了其状态跟踪的表达能力。
- 论文提出将稠密线性RNN参数化为对角线性RNN的定点,实现表达性和效率的平衡。
- 实验表明,该方法在状态跟踪任务上达到SOTA,并在其他任务上保持竞争力。
📝 摘要(中文)
线性循环神经网络(RNNs)和状态空间模型(SSMs),如Mamba,已成为Transformer架构中softmax注意力机制的有希望的替代方案,作为序列混合层。然而,目前的模型并没有表现出RNNs完整的状态跟踪表达能力,因为它们依赖于通道式(即对角)序列混合。本文研究了一大类稠密线性RNNs的参数化方法,将其作为可并行化的对角线性RNNs的定点。由此产生的模型可以在固定参数数量的情况下,自然地在表达性和效率之间进行权衡,并在状态跟踪基准测试$A_5$和$S_5$上取得最先进的结果,同时在复制和其他任务上匹配性能。
🔬 方法详解
问题定义:现有线性RNN和状态空间模型(如Mamba)在序列建模中表现出潜力,但由于依赖通道独立的对角序列混合,其状态跟踪的表达能力受到限制。稠密RNN虽然具有更强的表达能力,但计算复杂度较高,难以并行化。因此,如何在保持效率的同时,提升线性RNN的状态跟踪能力是一个关键问题。
核心思路:论文的核心思想是将一个复杂的稠密线性RNN建模为一系列可并行化的对角线性RNN的“不动点”。这意味着通过迭代应用一个简单的对角RNN,最终收敛到一个更具表达能力的稠密RNN。这种方法允许模型在表达能力和计算效率之间进行权衡,通过控制迭代次数来调整模型的复杂度。
技术框架:该方法的核心在于寻找一个对角线性RNN,其迭代应用后的不动点与目标稠密线性RNN相匹配。具体而言,给定一个目标稠密RNN的转移矩阵A,论文旨在找到一个对角矩阵D,使得D的多次迭代应用(通过某种变换)能够逼近A。整个框架可以分为以下几个步骤:1) 定义目标稠密RNN的结构;2) 设计对角线性RNN的迭代更新规则;3) 通过优化算法寻找合适的对角矩阵D;4) 使用得到的对角RNN进行序列建模。
关键创新:该方法最重要的创新点在于将稠密RNN的参数化问题转化为寻找对角RNN不动点的问题。这种转化使得模型可以在保持并行化优势的同时,获得更强的表达能力。与传统的稠密RNN相比,该方法避免了直接对稠密矩阵进行操作,从而降低了计算复杂度。与对角RNN相比,该方法通过迭代更新,获得了更强的状态跟踪能力。
关键设计:论文的关键设计包括:1) 如何定义对角线性RNN的迭代更新规则,使其能够逼近目标稠密RNN;2) 如何选择合适的优化算法来寻找最优的对角矩阵D;3) 如何在实际的序列建模任务中应用该方法,例如,如何将该方法集成到Transformer架构中。具体的参数设置和网络结构细节在论文中进行了详细描述,例如,迭代次数的选择、损失函数的设计等。
🖼️ 关键图片
📊 实验亮点
实验结果表明,该方法在状态跟踪基准测试$A_5$和$S_5$上取得了最先进的结果,超过了现有的线性RNN和SSM模型。同时,在复制任务和其他序列建模任务上,该方法也表现出与现有模型相当的性能。这些结果表明,该方法在表达性和效率之间取得了良好的平衡。
🎯 应用场景
该研究成果可应用于各种序列建模任务,例如语音识别、自然语言处理、时间序列预测等。特别是在需要长程依赖建模的场景下,该方法有望提供更高效、更精确的解决方案。此外,该方法还可以作为一种通用的RNN参数化方法,用于设计更具表达能力的循环神经网络。
📄 摘要(原文)
Linear recurrent neural networks (RNNs) and state-space models (SSMs) such as Mamba have become promising alternatives to softmax-attention as sequence mixing layers in Transformer architectures. Current models, however, do not exhibit the full state-tracking expressivity of RNNs because they rely on channel-wise (i.e. diagonal) sequence mixing. In this paper, we investigate parameterizations of a large class of dense linear RNNs as fixed-points of parallelizable diagonal linear RNNs. The resulting models can naturally trade expressivity for efficiency at a fixed number of parameters and achieve state-of-the-art results on the state-tracking benchmarks $A_5$ and $S_5$, while matching performance on copying and other tasks.