RAD: Redundancy-Aware Distillation for Hybrid Models via Self-Speculative Decoding

📄 arXiv: 2505.22135v1 📥 PDF

作者: Yuichiro Hoshino, Hideyuki Tachibana, Muneyoshi Inahara, Hiroto Takegawa

分类: cs.CL, cs.LG

发布日期: 2025-05-28

备注: 26 pages


💡 一句话要点

提出RAD:通过自推测解码实现混合模型冗余感知蒸馏

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 混合模型 知识蒸馏 自推测解码 冗余感知 Transformer 状态空间模型 模型优化

📋 核心要点

  1. 现有混合模型优化面临Transformer组件冗余的挑战,影响性能和效率。
  2. RAD通过自推测解码诊断冗余注意力层,并选择性地替换为SSM组件,再进行针对性蒸馏。
  3. 实验表明,RAD在数学和编码任务上显著优于基线模型,且收敛速度更快,性能更高。

📝 摘要(中文)

本文提出了一种名为RAD(Redundancy-Aware Distillation,冗余感知蒸馏)的新框架,用于优化混合模型(Transformer和状态空间模型SSM的结合)。RAD利用自推测解码作为诊断工具,识别模型中冗余的注意力层。然后,选择性地将这些层替换为SSM组件,并进行有针对性的(自)蒸馏。RAD侧重于将知识转移到被识别为冗余的组件上,同时考虑架构变化和特定的权重初始化策略。实验表明,在数学和编码任务上,使用RAD的自蒸馏显著优于原始基线模型。此外,RAD在标准知识蒸馏设置中也有效,与基线方法相比,收敛速度提高了约2倍。值得注意的是,即使使用较小的Llama-3.1 8B教师模型,RAD在GSM8K和CRUX上的得分分别为71.27和28.25,也明显高于从Llama-3.1 70B教师模型蒸馏得到的基线模型(GSM8K:46.17,CRUX:22.75)。RAD为混合模型的有效优化和性能提升提供了一条新途径。

🔬 方法详解

问题定义:论文旨在解决混合模型(Transformer和SSM)中Transformer组件冗余的问题。现有方法难以有效识别和处理这些冗余,导致模型效率低下,性能受限。尤其是在知识蒸馏场景下,如何将大型教师模型的知识高效地转移到混合结构的student模型是一个挑战。

核心思路:论文的核心思路是利用自推测解码来诊断Transformer中的冗余注意力层。自推测解码能够评估每个注意力层对最终预测的贡献,从而识别出可以被更高效的SSM组件替代的冗余层。通过选择性地替换这些层并进行知识蒸馏,可以优化混合模型的性能和效率。

技术框架:RAD框架包含以下主要阶段:1) 冗余识别:使用自推测解码识别Transformer中的冗余注意力层。2) 组件替换:将识别出的冗余层替换为SSM组件。3) 知识蒸馏:使用教师模型或自身作为教师,对替换后的混合模型进行知识蒸馏,重点关注被替换的组件。整体流程旨在优化混合模型的性能,同时减少计算开销。

关键创新:RAD的关键创新在于使用自推测解码作为诊断工具来识别Transformer中的冗余层。与传统的知识蒸馏方法不同,RAD不是简单地将整个教师模型的知识转移到学生模型,而是有选择性地关注那些对学生模型性能提升贡献最大的组件。这种冗余感知的方法能够更有效地利用教师模型的知识,并避免不必要的计算开销。

关键设计:RAD的关键设计包括:1) 自推测解码的实现细节:如何设计自推测解码过程,以准确评估每个注意力层的贡献。2) SSM组件的选择和配置:选择哪种SSM组件以及如何配置其参数,以最大程度地替代冗余的注意力层。3) 知识蒸馏的损失函数:设计合适的损失函数,以引导学生模型学习教师模型的知识,并优化被替换的SSM组件的性能。4) 权重初始化策略:针对替换后的SSM组件,采用特定的权重初始化策略,以加速训练过程并提高模型性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

RAD在GSM8K和CRUX数据集上取得了显著的性能提升。使用Llama-3.1 8B作为教师模型,RAD在GSM8K上达到了71.27的得分,在CRUX上达到了28.25的得分,远高于使用Llama-3.1 70B作为教师模型的基线模型(GSM8K:46.17,CRUX:22.75)。此外,RAD在标准知识蒸馏设置中实现了约2倍的收敛速度提升。

🎯 应用场景

RAD方法可应用于各种需要高效和高性能模型的场景,例如自然语言处理、语音识别和计算机视觉。通过优化混合模型,RAD可以降低计算成本,提高模型推理速度,从而使这些模型更易于部署在资源受限的设备上。此外,RAD还可以促进更大规模模型的知识迁移,加速AI技术的普及。

📄 摘要(原文)

Hybrid models combining Transformers and State Space Models (SSMs) are promising for balancing performance and efficiency. However, optimizing these hybrid models, particularly by addressing the potential redundancy inherent within the Transformer components, remains a significant challenge. In this paper, we propose RAD (Redundancy-Aware Distillation), a novel framework that uses self-speculative decoding as a diagnostic tool to identify redundant attention layers within the model. These identified layers are then selectively replaced with SSM components, followed by targeted (self-)distillation. Specifically, RAD focuses knowledge transfer on the components identified as redundant, considering architectural changes and specific weight initialization strategies. We experimentally demonstrate that self-distillation using RAD significantly surpasses the performance of the original base model on mathematical and coding tasks. Furthermore, RAD is also effective in standard knowledge distillation settings, achieving up to approximately 2x faster convergence compared to baseline methods. Notably, while a baseline model distilled from a Llama-3.1 70B teacher achieves scores of 46.17 on GSM8K and 22.75 on CRUX, RAD achieves significantly higher scores of 71.27 on GSM8K and 28.25 on CRUX, even when using a much smaller Llama-3.1 8B teacher. RAD offers a new pathway for efficient optimization and performance enhancement in the distillation of hybrid models.