Distilling to Hybrid Attention Models via KL-Guided Layer Selection
作者: Yanhong Li, Songlin Yang, Shawn Tan, Mayank Mishra, Rameswar Panda, Jiawei Zhou, Yoon Kim
分类: cs.CL, cs.AI
发布日期: 2025-12-23
💡 一句话要点
提出基于KL散度的层选择方法,用于将Softmax注意力Transformer蒸馏为混合注意力模型。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 模型蒸馏 注意力机制 线性注意力 知识迁移 层选择 KL散度 大型语言模型 推理效率
📋 核心要点
- 现有方法在将Softmax注意力Transformer蒸馏为混合架构时,层选择策略效率低或依赖特定数据集。
- 论文提出一种基于KL散度的层重要性评分方法,指导线性注意力层的选择,提升蒸馏效率。
- 实验表明,该方法优于均匀交错和依赖诊断数据集的方法,能更有效地蒸馏模型。
📝 摘要(中文)
本文提出了一种简单高效的层选择方法,用于将预训练的Softmax注意力Transformer蒸馏为更高效的混合架构,该架构交替使用Softmax和线性注意力层。这种方法旨在提高大型语言模型的推理效率,而无需从头开始进行昂贵的预训练。该方法使用从少量通用文本数据训练中获得的层重要性得分来决定哪些层转换为线性注意力变体。在选定层之后,我们使用最近的蒸馏流程(RADLADS),该流程包括注意力权重转移、隐藏状态对齐、基于KL散度的分布匹配,以及少量的微调。实验表明,该方法比现有的层选择方法更有效,包括基于固定比例均匀交错线性注意力的启发式方法,以及依赖于专门诊断数据集的更复杂的方法。
🔬 方法详解
问题定义:论文旨在解决如何高效地将预训练的Softmax注意力Transformer模型蒸馏成混合注意力模型的问题。现有方法,如均匀交错线性注意力层或依赖特定诊断数据集的方法,在层选择方面效率较低或泛化性不足,导致蒸馏后的模型性能提升有限。这些方法未能充分利用模型自身的特性来指导层选择,从而影响了蒸馏效果。
核心思路:论文的核心思路是利用少量通用文本数据训练得到的层重要性得分来指导线性注意力层的选择。通过计算每一层的重要性,可以确定哪些层对模型的性能影响更大,从而优先将这些层转换为线性注意力变体。这种方法能够更有效地利用模型自身的知识,避免盲目地进行层转换,从而提高蒸馏效率和模型性能。
技术框架:整体框架包括两个主要阶段:层选择阶段和蒸馏阶段。在层选择阶段,首先使用少量通用文本数据训练原始的Softmax注意力Transformer模型,然后计算每一层的重要性得分。基于这些得分,选择一部分层转换为线性注意力层。在蒸馏阶段,采用RADLADS流程,包括注意力权重转移、隐藏状态对齐、基于KL散度的分布匹配,以及少量的微调。
关键创新:最重要的技术创新点在于提出了一种基于KL散度的层重要性评分方法,用于指导线性注意力层的选择。与现有方法相比,该方法能够更有效地利用模型自身的知识,避免盲目地进行层转换,从而提高蒸馏效率和模型性能。这种方法不需要依赖特定的诊断数据集,具有更好的泛化性。
关键设计:层重要性得分的计算基于KL散度,衡量了每一层输出分布与原始模型输出分布之间的差异。选择KL散度较大的层进行转换,因为这些层对模型的性能影响更大。蒸馏阶段采用RADLADS流程,包括注意力权重转移、隐藏状态对齐、基于KL散度的分布匹配,以及少量的微调。KL散度也被用于分布匹配,确保蒸馏后的模型能够尽可能地逼近原始模型的输出分布。
🖼️ 关键图片
📊 实验亮点
论文提出的方法在层选择方面优于现有的均匀交错和依赖诊断数据集的方法。实验结果表明,使用该方法蒸馏后的模型在性能上取得了显著提升,同时保持了较高的推理效率。具体的性能数据和对比基线在论文中进行了详细的展示。
🎯 应用场景
该研究成果可应用于各种需要高效推理的大型语言模型场景,例如移动设备上的自然语言处理、实时对话系统和资源受限环境下的文本生成。通过将大型模型蒸馏为更小的混合注意力模型,可以显著降低计算成本和内存占用,从而实现更广泛的应用。
📄 摘要(原文)
Distilling pretrained softmax attention Transformers into more efficient hybrid architectures that interleave softmax and linear attention layers is a promising approach for improving the inference efficiency of LLMs without requiring expensive pretraining from scratch. A critical factor in the conversion process is layer selection, i.e., deciding on which layers to convert to linear attention variants. This paper describes a simple and efficient recipe for layer selection that uses layer importance scores derived from a small amount of training on generic text data. Once the layers have been selected we use a recent pipeline for the distillation process itself \citep[RADLADS;][]{goldstein2025radlads}, which consists of attention weight transfer, hidden state alignment, KL-based distribution matching, followed by a small amount of finetuning. We find that this approach is more effective than existing approaches for layer selection, including heuristics that uniformly interleave linear attentions based on a fixed ratio, as well as more involved approaches that rely on specialized diagnostic datasets.