Unmasking On-Policy Distillation: Where It Helps, Where It Hurts, and Why
作者: Mohammadreza Armandpour, Fatih Ilhan, David Harrison, Ajay Jaiswal, Duc N. M Hoang, Fartash Faghri, Yizhe Zhang, Minsik Cho, Mehrdad Farajtabar
分类: cs.LG, cs.AI
发布日期: 2026-05-11
💡 一句话要点
提出一种无需训练的诊断框架,通过梯度对齐分析揭示策略内蒸馏在推理模型训练中的作用机制。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 策略内蒸馏 推理模型 梯度分析 思维链 模型诊断 知识蒸馏
📋 核心要点
- 现有蒸馏方法依赖昂贵的训练实验,且宏观性能指标无法揭示不同教师模型或上下文在逐标记层面的具体影响。
- 提出一种无需训练的诊断框架,通过计算理想梯度与蒸馏梯度的余弦相似度,量化蒸馏信号在推理过程中的有效性。
- 实验表明蒸馏在错误路径上更具引导价值,且最优蒸馏策略高度依赖于模型能力与任务特性的动态匹配。
📝 摘要(中文)
策略内蒸馏(On-policy distillation)为推理模型的训练提供了密集的逐标记监督信号,但其在何种条件下有效或有害尚不明确。目前,选择合适的教师模型、确定自蒸馏的最佳上下文以及评估标记级信号的有效性,通常需要昂贵的训练实验,且宏观性能指标掩盖了细粒度的动态变化。本文引入了一种无需训练的诊断框架,可在逐标记、逐问题和逐教师的最高分辨率下进行分析。我们推导了理想的逐节点梯度,即能最大化学生模型成功概率的参数更新,并开发了一种可扩展的目标展开算法来高效估计该梯度。通过计算理想梯度与蒸馏梯度之间的余弦相似度(梯度对齐得分),我们量化了特定配置对理想信号的逼近程度。研究发现,在错误路径上,蒸馏引导与理想梯度的对齐度显著高于正确路径;此外,最优蒸馏上下文取决于学生模型能力与任务的耦合,不存在通用的最优配置,这强调了进行逐任务、逐标记诊断分析的必要性。
🔬 方法详解
问题定义:论文旨在解决策略内蒸馏中“何时有效、何时有害”的黑盒问题。现有方法缺乏对蒸馏信号质量的细粒度评估手段,导致研究者难以判断特定教师或上下文是否真正提升了学生的推理能力。
核心思路:引入“理想梯度”概念,即能使学生模型在特定步骤成功概率最大化的理论最优更新方向。通过衡量实际蒸馏梯度与该理想梯度的对齐程度,实现对蒸馏信号质量的定量诊断。
技术框架:框架包含三个核心环节:首先定义理想逐节点梯度;其次利用可扩展的目标展开(Targeted-rollout)算法高效估计该梯度;最后计算梯度对齐得分(Gradient Alignment Score),即理想梯度与蒸馏梯度间的余弦相似度。
关键创新:最大的创新在于将蒸馏评估从“结果导向”转变为“过程导向”。通过无需训练的诊断手段,在不进行完整微调的情况下,即可评估不同蒸馏配置对模型学习轨迹的潜在贡献。
关键设计:核心技术细节在于目标展开算法,它能够处理长思维链中的梯度估计问题。此外,通过对比正确与错误路径上的对齐得分,揭示了教师信号在学生已掌握知识点上可能产生的噪声干扰效应。
🖼️ 关键图片
📊 实验亮点
研究发现蒸馏信号在学生模型表现较差的路径上具有更高的对齐度,即蒸馏在纠错时效果显著;而在学生已能正确推理的路径上,教师信号往往引入噪声。实验证明不存在“一刀切”的最优蒸馏配置,必须根据任务特性与模型规模进行定制化诊断。
🎯 应用场景
该研究适用于大语言模型(LLM)的推理能力增强,特别是在思维链(CoT)微调阶段。它为模型开发者提供了一种高效的诊断工具,用于筛选最优的教师模型、确定蒸馏的上下文窗口,并优化数据合成策略,从而降低训练成本并提升模型推理的鲁棒性。
📄 摘要(原文)
On-policy distillation offers dense, per-token supervision for training reasoning models; however, it remains unclear under which conditions this signal is beneficial and under which it is detrimental. Which teacher model should be used, and in the case of self-distillation, which specific context should serve as the supervisory signal? Does the optimal choice vary from one token to the next? At present, addressing these questions typically requires costly training runs whose aggregate performance metrics obscure the dynamics at the level of individual tokens. We introduce a training-free diagnostic framework that operates at the highest resolution: per token, per question, and per teacher. We derive an ideal per-node gradient defined as the parameter update that maximally increases the student's probability of success. We then develop a scalable targeted-rollout algorithm to estimate this gradient efficiently, even for long chains of intermediate thoughts. The gradient alignment score, defined as the cosine similarity between this ideal gradient and any given distillation gradient, quantifies the extent to which a particular configuration approximates the ideal signal. Across a range of self-distillation settings and external teacher models, we observe that distillation guidance exhibits substantially higher alignment with the ideal on incorrect rollouts than on correct ones, where the student already performs well and the teacher's signal tends to become noisy. Furthermore, we find that the optimal distillation context depends jointly on the student model's capacity and the target task, and that no single universally effective configuration emerges. These findings motivate the use of per-task, per-token diagnostic analyses for distillation.