RecurFormer: Not All Transformer Heads Need Self-Attention
作者: Ruiqing Yan, Linghan Zheng, Xingbo Du, Han Zou, Yufeng Guo, Jianfei Yang
分类: cs.CL, cs.AI, cs.LG
发布日期: 2024-10-10
💡 一句话要点
RecurFormer:用线性循环网络替换Transformer中冗余自注意力头,提升长文本推理效率。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: Transformer 线性循环网络 Mamba 注意力机制 长文本处理 推理效率 模型压缩
📋 核心要点
- Transformer模型在处理长文本时,注意力机制导致计算和内存开销巨大,限制了推理效率。
- RecurFormer通过识别并替换Transformer中负责局部依赖的注意力头为更高效的线性循环网络(Mamba),降低计算成本。
- 实验表明,RecurFormer在保持模型性能的同时,显著提升了推理效率,尤其是在处理长文本时。
📝 摘要(中文)
基于Transformer的大型语言模型(LLMs)在建模复杂语言模式方面表现出色,但在推理过程中面临巨大的计算成本,特别是对于长输入,这是由于注意力机制的内存开销。我们观察到,某些注意力头的分布表现出一种“近因感知”特性,即注意力权重集中在查询token附近的token上,专注于局部和短程依赖。基于此,我们提出了RecurFormer,一种新颖的架构,用线性循环神经网络(RNNs),特别是Mamba架构,替换这些注意力头。这种替换减少了缓存大小,而无需逐出token,从而保持了生成质量。RecurFormer保留了通过剩余注意力头建模远程依赖的能力,并允许通过持续训练重用预训练的Transformer LLMs权重。实验表明,RecurFormer在显著提高推理效率的同时,匹配了原始模型的性能。我们的方法为基于Transformer的LLMs推理的计算挑战提供了一个实用的解决方案,使其对涉及长输入的任务极具吸引力。
🔬 方法详解
问题定义:Transformer模型在处理长文本时,自注意力机制的计算复杂度是序列长度的平方级别,导致推理速度慢,内存占用高。尤其是在生成任务中,需要缓存大量的中间状态,进一步加剧了这个问题。现有方法通常采用剪枝、量化等方式压缩模型,但可能会损失模型性能。
核心思路:论文的核心思路是观察到Transformer中并非所有注意力头都负责建模长程依赖,部分注意力头主要关注局部或短程依赖关系。因此,可以将这些“近因感知”的注意力头替换为更高效的线性循环神经网络(RNN),从而降低计算和内存开销。
技术框架:RecurFormer的整体架构是在预训练的Transformer模型基础上,选择性地将部分注意力头替换为Mamba架构的线性循环层。具体流程包括:1) 分析预训练Transformer模型的注意力头,识别出“近因感知”的注意力头;2) 将这些注意力头替换为Mamba层;3) 使用持续训练(continual training)微调模型,以恢复因替换操作可能造成的性能损失。
关键创新:RecurFormer的关键创新在于发现了Transformer中注意力头的冗余性,并提出了一种选择性替换注意力头的策略。与直接压缩整个模型不同,RecurFormer保留了Transformer建模长程依赖的能力,同时利用线性循环网络的高效性处理局部依赖。这种混合架构在效率和性能之间取得了更好的平衡。
关键设计:论文的关键设计包括:1) 如何识别“近因感知”的注意力头:通过分析注意力权重的分布,例如计算注意力权重集中在query token附近的程度。2) 使用Mamba架构作为线性循环层:Mamba具有线性复杂度,并且能够有效地建模序列数据。3) 使用持续训练微调模型:在替换注意力头后,使用少量数据对模型进行微调,以恢复模型性能。
🖼️ 关键图片
📊 实验亮点
实验结果表明,RecurFormer在保持与原始Transformer模型相当的性能水平下,显著提升了推理效率。具体来说,RecurFormer在多个benchmark上取得了与原始模型相近的perplexity,同时降低了推理延迟和内存占用。例如,在处理长文本时,RecurFormer的推理速度提升了X倍(具体数值未知),内存占用降低了Y%(具体数值未知)。
🎯 应用场景
RecurFormer适用于需要处理长文本的各种自然语言处理任务,例如长文档摘要、机器翻译、代码生成、对话系统等。该方法可以显著降低推理成本,使得大型语言模型能够更高效地部署在资源受限的设备上,并加速长文本生成任务的推理速度,具有广泛的应用前景。
📄 摘要(原文)
Transformer-based large language models (LLMs) excel in modeling complex language patterns but face significant computational costs during inference, especially with long inputs due to the attention mechanism's memory overhead. We observe that certain attention heads exhibit a distribution where the attention weights concentrate on tokens near the query token, termed as recency aware, which focuses on local and short-range dependencies. Leveraging this insight, we propose RecurFormer, a novel architecture that replaces these attention heads with linear recurrent neural networks (RNNs), specifically the Mamba architecture. This replacement reduces the cache size without evicting tokens, thus maintaining generation quality. RecurFormer retains the ability to model long-range dependencies through the remaining attention heads and allows for reusing pre-trained Transformer-based LLMs weights with continual training. Experiments demonstrate that RecurFormer matches the original model's performance while significantly enhancing inference efficiency. Our approach provides a practical solution to the computational challenges of Transformer-based LLMs inference, making it highly attractive for tasks involving long inputs.