The Mamba in the Llama: Distilling and Accelerating Hybrid Models

📄 arXiv: 2408.15237v4 📥 PDF

作者: Junxiong Wang, Daniele Paliotta, Avner May, Alexander M. Rush, Tri Dao

分类: cs.LG, cs.AI

发布日期: 2024-08-27 (更新: 2025-06-27)

备注: NeurIPS 2024. v4 updates: mention concurrent work of speculative decoding for SSM

🔗 代码/项目: GITHUB | GITHUB


💡 一句话要点

提出Transformer蒸馏至线性RNN的混合模型,并用硬件感知推测解码加速推理。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 线性RNN Mamba Transformer 知识蒸馏 推测解码 混合模型 模型加速

📋 核心要点

  1. 现有Transformer模型部署成本高昂,线性RNN如Mamba具有更优的部署特性,但训练成本高。
  2. 通过知识蒸馏,将预训练Transformer模型的知识迁移到混合模型(部分Transformer+Mamba),降低训练成本。
  3. 提出硬件感知的推测解码算法,加速Mamba和混合模型的推理速度,进一步提升部署效率。

📝 摘要(中文)

本文研究了将大规模Transformer模型转换为线性RNN以进行部署的挑战。通过重用注意力层的线性投影权重,将大型Transformer模型蒸馏成线性RNN是可行的。由此产生的混合模型,保留了四分之一的注意力层,在聊天基准测试中实现了与原始Transformer相当的性能,并且在聊天和通用基准测试中均优于从头开始训练的开源混合Mamba模型。此外,本文还提出了一种硬件感知的推测解码算法,加速了Mamba和混合模型的推理速度。结果表明,在有限的计算资源下,可以移除大部分原始注意力层,并更有效地生成模型。最佳模型从Llama3-8B-Instruct蒸馏而来,在AlpacaEval 2上针对GPT-4的长度控制胜率为29.61,在MT-Bench上为7.35,超过了最佳的8B规模指令调整线性RNN模型。蒸馏模型还表现出自然的长度外推能力,在20倍蒸馏长度的干草堆中寻针测试中表现出几乎完美的准确性。代码和预训练检查点已开源。

🔬 方法详解

问题定义:现有的大规模Transformer模型虽然性能强大,但在部署时面临计算资源需求高、推理速度慢等问题。线性RNN(如Mamba)虽然具有更高效的推理特性,但从头开始训练的成本很高,且性能可能不如预训练的Transformer模型。因此,如何利用已有的Transformer模型知识,构建高效且高性能的线性RNN模型是一个关键问题。

核心思路:本文的核心思路是通过知识蒸馏,将预训练的Transformer模型的知识迁移到混合模型中。具体来说,利用Transformer模型中注意力层的线性投影权重,初始化线性RNN模型,从而避免从头开始训练。同时,保留部分Transformer层,以提升模型的性能。此外,还提出了一种硬件感知的推测解码算法,进一步加速推理过程。

技术框架:整体框架包含两个主要阶段:1) 模型蒸馏:将预训练的Transformer模型(如Llama3)蒸馏成混合模型,该模型包含部分Transformer层和Mamba层。蒸馏过程中,重用Transformer注意力层的线性投影权重来初始化Mamba层。2) 推理加速:使用硬件感知的推测解码算法,加速Mamba和混合模型的推理速度。该算法根据硬件特性,动态调整推测步长,从而优化推理效率。

关键创新:本文的关键创新在于:1) 提出了一种有效的Transformer到线性RNN的蒸馏方法,通过重用注意力层的线性投影权重,降低了训练成本。2) 设计了一种混合模型架构,结合了Transformer和Mamba的优点,在性能和效率之间取得了平衡。3) 提出了一种硬件感知的推测解码算法,进一步加速了Mamba和混合模型的推理速度。

关键设计:在模型蒸馏过程中,选择合适的Transformer层数和Mamba层数是一个关键设计。实验表明,保留四分之一的Transformer层可以获得较好的性能。在推测解码算法中,需要根据硬件特性(如GPU的内存带宽和计算能力)调整推测步长。损失函数主要采用交叉熵损失,用于衡量模型预测结果与真实标签之间的差异。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,从Llama3-8B-Instruct蒸馏得到的混合模型在AlpacaEval 2上针对GPT-4的长度控制胜率为29.61,在MT-Bench上为7.35,超过了最佳的8B规模指令调整线性RNN模型。此外,该模型在20倍蒸馏长度的干草堆中寻针测试中表现出几乎完美的准确性,表明其具有良好的长度外推能力。

🎯 应用场景

该研究成果可应用于各种需要高效部署的大语言模型场景,例如移动设备上的聊天机器人、边缘计算设备上的智能助手等。通过将大型Transformer模型蒸馏成更小、更快的混合模型,可以在资源受限的环境中实现高性能的自然语言处理应用。此外,该方法还可以促进线性RNN模型在实际应用中的普及。

📄 摘要(原文)

Linear RNN architectures, like Mamba, can be competitive with Transformer models in language modeling while having advantageous deployment characteristics. Given the focus on training large-scale Transformer models, we consider the challenge of converting these pretrained models for deployment. We demonstrate that it is feasible to distill large Transformers into linear RNNs by reusing the linear projection weights from attention layers with academic GPU resources. The resulting hybrid model, which incorporates a quarter of the attention layers, achieves performance comparable to the original Transformer in chat benchmarks and outperforms open-source hybrid Mamba models trained from scratch with trillions of tokens in both chat benchmarks and general benchmarks. Moreover, we introduce a hardware-aware speculative decoding algorithm that accelerates the inference speed of Mamba and hybrid models. Overall we show how, with limited computation resources, we can remove many of the original attention layers and generate from the resulting model more efficiently. Our top-performing model, distilled from Llama3-8B-Instruct, achieves a 29.61 length-controlled win rate on AlpacaEval 2 against GPT-4 and 7.35 on MT-Bench, surpassing the best 8B scale instruction-tuned linear RNN model. We also find that the distilled model has natural length extrapolation, showing almost perfect accuracy in the needle-in-a-haystack test at 20x the distillation length. Code and pre-trained checkpoints are open-sourced at https://github.com/jxiw/MambaInLlama and https://github.com/itsdaniele/speculative_mamba.