Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models
作者: Aviv Bick, Kevin Y. Li, Eric P. Xing, J. Zico Kolter, Albert Gu
分类: cs.LG, cs.AI
发布日期: 2024-08-19 (更新: 2025-02-08)
💡 一句话要点
提出MOHAWK方法以将Transformer知识蒸馏至子二次模型
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 知识蒸馏 状态空间模型 Transformer 子二次模型 计算效率 自然语言处理 模型压缩
📋 核心要点
- 现有的Transformer模型在推理时由于自注意力机制的二次时间复杂度面临性能瓶颈。
- 本文提出的MOHAWK方法通过逐步匹配SSM中的混合矩阵和隐藏单元,实现了Transformer到SSM的知识蒸馏。
- 实验结果表明,使用少量训练数据的Phi-Mamba在性能上显著优于以往的非Transformer模型,展示了新的模型构建思路。
📝 摘要(中文)
Transformer架构已成为语言建模等领域的主流,但在许多推理场景中由于其二次时间复杂度的自注意力机制而受到限制。最近提出的子二次架构如Mamba显示出潜力,但其预训练所需的计算资源远低于最强的Transformer模型。本文提出了一种方法,能够将预训练的Transformer架构蒸馏为状态空间模型(SSMs)等替代架构。我们的方法MOHAWK通过匹配SSM中的不同粒度,逐步蒸馏Transformer架构,最终在使用少于1%训练数据的情况下,Phi-Mamba在性能上显著优于所有过去的开源非Transformer模型。
🔬 方法详解
问题定义:本文旨在解决Transformer模型在推理时的计算效率问题,尤其是其二次时间复杂度导致的性能限制。现有的子二次模型虽然有潜力,但通常预训练资源不足,难以与强大的Transformer模型竞争。
核心思路:我们提出的MOHAWK方法通过将Transformer和SSM视为对token序列应用不同形式的混合矩阵,逐步蒸馏Transformer架构。该方法首先匹配混合矩阵,然后匹配每个块的隐藏单元,最后匹配端到端的预测。
技术框架:MOHAWK的整体架构包括三个主要阶段:1) 匹配混合矩阵,2) 匹配隐藏单元,3) 匹配最终预测。每个阶段都针对不同的粒度进行优化,以确保知识的有效传递。
关键创新:MOHAWK的核心创新在于其逐步蒸馏的策略,使得子二次模型能够有效利用Transformer的预训练知识,显著提升了模型的性能。与现有方法相比,MOHAWK在计算资源利用上更为高效。
关键设计:在参数设置上,MOHAWK使用了3B和5B的token进行训练,损失函数设计上强调了不同粒度的匹配,确保了模型在各个阶段的学习效果。
🖼️ 关键图片
📊 实验亮点
实验结果显示,使用MOHAWK方法的Phi-Mamba在仅用3B tokens的情况下,性能显著优于以往的非Transformer模型,展示了在计算资源极为有限的情况下仍能取得优异表现的潜力。
🎯 应用场景
该研究的潜在应用领域包括自然语言处理、语音识别和其他需要高效推理的任务。通过将Transformer的知识蒸馏至更高效的模型,MOHAWK为资源受限的环境提供了新的解决方案,具有重要的实际价值和广泛的应用前景。
📄 摘要(原文)
Transformer architectures have become a dominant paradigm for domains like language modeling but suffer in many inference settings due to their quadratic-time self-attention. Recently proposed subquadratic architectures, such as Mamba, have shown promise, but have been pretrained with substantially less computational resources than the strongest Transformer models. In this work, we present a method that is able to distill a pretrained Transformer architecture into alternative architectures such as state space models (SSMs). The key idea to our approach is that we can view both Transformers and SSMs as applying different forms of mixing matrices over the token sequences. We can thus progressively distill the Transformer architecture by matching different degrees of granularity in the SSM: first matching the mixing matrices themselves, then the hidden units at each block, and finally the end-to-end predictions. Our method, called MOHAWK, is able to distill a Mamba-2 variant based on the Phi-1.5 architecture (Phi-Mamba) using only 3B tokens and a hybrid version (Hybrid Phi-Mamba) using 5B tokens. Despite using less than 1% of the training data typically used to train models from scratch, Phi-Mamba boasts substantially stronger performance compared to all past open-source non-Transformer models. MOHAWK allows models like SSMs to leverage computational resources invested in training Transformer-based architectures, highlighting a new avenue for building such models.