An Empirical Study of Mamba-based Language Models
作者: Roger Waleffe, Wonmin Byeon, Duncan Riach, Brandon Norick, Vijay Korthikanti, Tri Dao, Albert Gu, Ali Hatamizadeh, Sudhakar Singh, Deepak Narayanan, Garvit Kulshreshtha, Vartika Singh, Jared Casper, Jan Kautz, Mohammad Shoeybi, Bryan Catanzaro
分类: cs.LG, cs.CL
发布日期: 2024-06-12
💡 一句话要点
大规模Mamba语言模型实证研究:性能对比与混合架构探索
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: Mamba模型 状态空间模型 Transformer 混合架构 长上下文建模
📋 核心要点
- Transformer模型在长序列处理中面临计算复杂度高和内存需求大的挑战,限制了其应用。
- 论文探索了基于Mamba的SSM模型,并提出一种混合架构,旨在兼顾性能和效率。
- 实验表明,混合Mamba架构在多种任务上超越了Transformer,尤其在长上下文任务中表现出色。
📝 摘要(中文)
本文对基于Mamba的选择性状态空间模型(SSM)与Transformer模型进行了直接比较,旨在评估它们在大规模训练下的性能。Mamba模型旨在克服Transformer在序列长度上的二次计算复杂度和推理时key-value缓存带来的大内存需求等缺点。研究比较了80亿参数的Mamba、Mamba-2和Transformer模型,它们均在高达3.5T tokens的相同数据集上进行训练。此外,还评估了一种混合架构(Mamba-2-Hybrid),该架构由43%的Mamba-2层、7%的注意力层和50%的MLP层组成。结果表明,纯SSM在许多任务上与Transformer相当或超过Transformer,但在需要强复制或上下文学习能力的任务(如5-shot MMLU、Phonebook)或长上下文推理方面落后于Transformer。然而,8B Mamba-2-Hybrid在所有12个标准任务上的表现均优于8B Transformer(平均提升2.65分),并且在推理时生成tokens的速度预计快8倍。为了验证长上下文能力,还评估了扩展到支持16K、32K和128K序列的Mamba-2-Hybrid和Transformer的变体。在额外的23个长上下文任务中,混合模型继续与Transformer的表现相当或超过Transformer。为了方便进一步研究,作者发布了模型检查点以及用于训练模型的代码,作为NVIDIA Megatron-LM项目的一部分。
🔬 方法详解
问题定义:现有Transformer模型在处理长序列时,计算复杂度呈二次方增长,并且推理时需要大量的key-value缓存,导致内存需求巨大。这限制了Transformer在长文本建模和实时推理等场景的应用。Mamba等SSM模型旨在解决这些问题,但其在大规模训练下的性能表现需要进一步研究。
核心思路:论文的核心思路是通过实证研究,直接比较Mamba、Mamba-2和Transformer模型在大规模训练下的性能差异,并探索一种混合架构,结合Mamba的效率和Transformer的优势。这种混合架构旨在在各种任务上实现更好的性能和效率。
技术框架:研究采用了三种主要的模型架构:纯Mamba模型、纯Transformer模型和Mamba-2-Hybrid混合模型。所有模型都在相同的数据集上进行训练,并使用相同的训练流程。混合模型由Mamba-2层、注意力层和MLP层组成,比例分别为43%、7%和50%。评估涵盖了标准语言建模任务和长上下文任务。
关键创新:关键创新在于对大规模Mamba模型进行了全面的实证研究,并提出了Mamba-2-Hybrid混合架构。这种混合架构结合了Mamba的效率和Transformer的上下文学习能力,在多种任务上取得了优异的性能。与纯Mamba模型相比,混合模型在需要强复制或上下文学习能力的任务上表现更好。
关键设计:Mamba-2-Hybrid模型的关键设计在于混合了Mamba-2层、注意力层和MLP层。具体比例为43% Mamba-2, 7% attention, and 50% MLP。这种混合比例的选择可能基于经验观察和实验结果。此外,研究还探索了不同序列长度(16K、32K和128K)下的模型性能,以评估其长上下文处理能力。
🖼️ 关键图片
📊 实验亮点
实验结果表明,8B Mamba-2-Hybrid模型在12个标准任务上的平均表现优于8B Transformer模型2.65分,并且在推理时生成tokens的速度预计快8倍。在长上下文任务中,混合模型也与Transformer的表现相当或超过Transformer。这些结果表明,混合Mamba架构具有很强的竞争力,有望成为Transformer的替代方案。
🎯 应用场景
该研究成果可应用于各种自然语言处理任务,尤其是在需要处理长文本和实时推理的场景,如机器翻译、文本摘要、对话系统和代码生成等。混合Mamba架构有望在资源受限的环境中实现高效的语言建模,并加速相关应用的部署。
📄 摘要(原文)
Selective state-space models (SSMs) like Mamba overcome some of the shortcomings of Transformers, such as quadratic computational complexity with sequence length and large inference-time memory requirements from the key-value cache. Moreover, recent studies have shown that SSMs can match or exceed the language modeling capabilities of Transformers, making them an attractive alternative. In a controlled setting (e.g., same data), however, studies so far have only presented small scale experiments comparing SSMs to Transformers. To understand the strengths and weaknesses of these architectures at larger scales, we present a direct comparison between 8B-parameter Mamba, Mamba-2, and Transformer models trained on the same datasets of up to 3.5T tokens. We also compare these models to a hybrid architecture consisting of 43% Mamba-2, 7% attention, and 50% MLP layers (Mamba-2-Hybrid). Using a diverse set of tasks, we answer the question of whether Mamba models can match Transformers at larger training budgets. Our results show that while pure SSMs match or exceed Transformers on many tasks, they lag behind Transformers on tasks which require strong copying or in-context learning abilities (e.g., 5-shot MMLU, Phonebook) or long-context reasoning. In contrast, we find that the 8B Mamba-2-Hybrid exceeds the 8B Transformer on all 12 standard tasks we evaluated (+2.65 points on average) and is predicted to be up to 8x faster when generating tokens at inference time. To validate long-context capabilities, we provide additional experiments evaluating variants of the Mamba-2-Hybrid and Transformer extended to support 16K, 32K, and 128K sequences. On an additional 23 long-context tasks, the hybrid model continues to closely match or exceed the Transformer on average. To enable further study, we release the checkpoints as well as the code used to train our models as part of NVIDIA's Megatron-LM project.