Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling
作者: Liliang Ren, Yang Liu, Yadong Lu, Yelong Shen, Chen Liang, Weizhu Chen
分类: cs.CL, cs.LG
发布日期: 2024-06-11 (更新: 2025-02-28)
备注: Accepted by ICLR 2025. Camera-ready Version
🔗 代码/项目: GITHUB
💡 一句话要点
Samba:一种高效的无限上下文语言建模的简单混合状态空间模型
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 长序列建模 状态空间模型 注意力机制 混合架构 语言模型
📋 核心要点
- 现有序列建模方法在处理无限上下文长度时面临计算复杂度高或长度泛化能力有限的挑战。
- Samba通过结合Mamba和滑动窗口注意力,选择性压缩序列信息并精确回忆近期记忆,实现高效建模。
- 实验表明,Samba在多个基准测试中优于现有模型,并在长上下文零样本学习和外推方面表现出色。
📝 摘要(中文)
高效地建模具有无限上下文长度的序列一直是一个具有挑战性的问题。以往的方法要么遭受二次计算复杂度,要么在长度泛化中受到有限的外推能力限制。本文提出Samba,一种简单的混合架构,它逐层地将选择性状态空间模型(SSM)Mamba与滑动窗口注意力(SWA)相结合。Samba选择性地将给定的序列压缩为循环隐藏状态,同时仍然保持使用注意力机制精确回忆最近记忆的能力。我们将Samba扩展到3.8B参数,使用3.2T训练tokens,并证明它在各种基准测试中显著优于最先进的模型。在4K长度的序列上预训练后,Samba在高达1M的上下文长度中以零样本方式显示出改进的困惑度。当在4K长度的序列上进行微调时,Samba有效地外推到256K上下文长度,在Passkey Retrieval任务上具有完美的记忆召回,并且在具有挑战性的Phonebook任务上表现出优于全注意力模型的检索外推能力。作为一种线性时间序列模型,对于128K长度的用户提示,Samba实现了比具有分组查询注意力的Transformer高3.73倍的吞吐量,并且在生成具有无限流式传输的64K tokens时实现了3.64倍的加速。我们在开源数据上训练的代码可在https://github.com/microsoft/Samba公开获取。
🔬 方法详解
问题定义:论文旨在解决长序列建模中,现有Transformer模型计算复杂度高,以及状态空间模型(SSM)长度外推能力不足的问题。Transformer的注意力机制复杂度是序列长度的平方级别,难以处理超长序列。而单纯的SSM虽然复杂度是线性的,但在长序列上的记忆能力和精度有所欠缺。
核心思路:Samba的核心思路是结合Mamba(一种选择性状态空间模型)和滑动窗口注意力(SWA)的优势。Mamba负责压缩序列信息到循环隐藏状态,实现高效的长程依赖建模;SWA则负责精确回忆最近的记忆,提升短期记忆的准确性。通过层级混合这两种机制,Samba在计算效率和记忆能力之间取得平衡。
技术框架:Samba的整体架构是一个层级的混合模型。每一层都包含Mamba模块和SWA模块。输入序列首先经过Mamba模块,该模块将序列压缩成循环隐藏状态。然后,SWA模块利用注意力机制对最近的序列片段进行精确建模。Mamba和SWA的输出可以进行加权融合,或者通过其他方式进行组合,形成该层的最终输出。多层这样的混合结构堆叠起来,就构成了完整的Samba模型。
关键创新:Samba的关键创新在于提出了一个简单有效的混合架构,将SSM和注意力机制有机结合。这种混合方式不仅继承了SSM的线性复杂度,还保留了注意力机制的精确记忆能力。此外,Samba的设计允许模型在长序列上进行高效的训练和推理,并且具有良好的长度外推能力。
关键设计:Samba的关键设计包括:1) Mamba模块的选择性扫描机制,允许模型根据输入动态调整状态转移矩阵;2) SWA模块的窗口大小和注意力头的数量,需要根据具体任务进行调整;3) Mamba和SWA输出的融合方式,例如简单的加权平均或者更复杂的门控机制;4) 损失函数的设计,通常采用交叉熵损失,并可能加入正则化项以防止过拟合。
🖼️ 关键图片
📊 实验亮点
Samba在多个基准测试中取得了显著的性能提升。例如,在长上下文语言建模任务中,Samba在高达1M的上下文长度下表现出改进的困惑度。在Passkey Retrieval任务中,Samba能够有效地外推到256K的上下文长度,并实现完美的记忆召回。此外,Samba在推理速度方面也优于Transformer模型,对于128K长度的用户提示,Samba实现了比具有分组查询注意力的Transformer高3.73倍的吞吐量。
🎯 应用场景
Samba在需要处理超长序列的语言建模任务中具有广泛的应用前景,例如:长文档摘要、代码生成、对话系统、以及需要理解长时间上下文的机器人控制等。其高效的计算性能和良好的长度外推能力,使得它能够处理传统Transformer模型难以胜任的任务,并有望推动相关领域的发展。
📄 摘要(原文)
Efficiently modeling sequences with infinite context length has long been a challenging problem. Previous approaches have either suffered from quadratic computational complexity or limited extrapolation ability in length generalization. In this work, we present Samba, a simple hybrid architecture that layer-wise combines Mamba, a selective State Space Model (SSM), with Sliding Window Attention (SWA). Samba selectively compresses a given sequence into recurrent hidden states while still maintaining the ability to precisely recall recent memories with the attention mechanism. We scale Samba up to 3.8B parameters with 3.2T training tokens and demonstrate that it significantly outperforms state-of-the-art models across a variety of benchmarks. Pretrained on sequences of 4K length, Samba shows improved perplexity in context lengths of up to 1M in zero-shot. When finetuned on 4K-length sequences, Samba efficiently extrapolates to a 256K context length with perfect memory recall on the Passkey Retrieval task, and exhibits superior retrieval extrapolation on the challenging Phonebook task compared to full-attention models. As a linear-time sequence model, Samba achieves a 3.73x higher throughput compared to Transformers with grouped-query attention for user prompts of 128K length, and a 3.64x speedup when generating 64K tokens with unlimited streaming. Our code for training on open source data is publicly available at https://github.com/microsoft/Samba.