KL for a KL: On-Policy Distillation with Control Variate Baseline

📄 arXiv: 2605.07865v1 📥 PDF

作者: Minjae Oh, Sangjun Song, Gyubin Choi, Yunho Choi, Yohan Jo

分类: cs.LG, cs.AI, cs.CL

发布日期: 2026-05-08


💡 一句话要点

提出vOPD方法:通过引入控制变量基线,解决在线策略蒸馏中的梯度方差不稳定问题。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 在线策略蒸馏 强化学习 方差缩减 大语言模型 模型压缩 推理能力提升

📋 核心要点

  1. 在线策略蒸馏(OPD)依赖单样本蒙特卡洛估计,导致梯度方差过大,训练过程极不稳定。
  2. 将OPD重构为策略梯度强化学习问题,利用解析形式的价值函数作为控制变量基线,实现方差缩减。
  3. 在数学与科学推理任务中,vOPD在保持计算高效的同时,性能显著优于传统OPD并媲美全词表基线。

📝 摘要(中文)

在线策略蒸馏(OPD)已成为大语言模型(尤其是推理领域)的主流后训练范式。然而,由于单样本蒙特卡洛估计器的高梯度方差,OPD在实践中表现不稳定,且缺乏成熟的训练方案。本文提出vOPD(带控制变量基线的在线策略蒸馏),将OPD建模为策略梯度强化学习,并引入强化学习文献中的控制变量基线(即价值函数)来稳定训练。研究表明,OPD的价值函数可解析为学生模型与教师模型之间的逐Token负反向KL散度,无需额外的Critic网络或推理开销。相比于计算全词表KL散度的高昂成本或Top-k截断带来的偏差,vOPD保留了轻量级的单样本估计器,通过减去价值函数作为基线,在保持梯度无偏的同时显著降低了方差。在数学和科学推理基准测试中,vOPD表现优于传统OPD,并达到了全词表基线的性能水平。

🔬 方法详解

问题定义:OPD旨在通过在线采样教师模型的输出分布来训练学生模型,但其梯度估计器方差极大,导致训练过程难以收敛且对超参数高度敏感,现有方法在计算效率与估计偏差之间难以平衡。

核心思路:借鉴强化学习中的方差缩减技术,引入控制变量(Control Variate)基线。论文发现OPD中的价值函数存在闭式解,即学生与教师模型间的逐Token负反向KL散度,这使得无需额外训练Critic即可实现方差控制。

技术框架:vOPD框架直接利用学生和教师模型的前向传播结果计算价值函数,将其作为基线从策略梯度中减去。该过程无需额外的推理步骤,保持了训练的轻量化特性。

关键创新:核心创新在于将KL散度作为控制变量基线,证明了该基线能使梯度估计保持无偏且方差更小。此外,引入Top-k近似基线进一步降低了计算开销,同时避免了传统截断方法带来的偏差问题。

关键设计:利用教师模型输出的对数概率与学生模型输出的对数概率之差作为价值函数项,通过在策略梯度更新公式中引入该项,有效抵消了采样带来的高方差噪声,实现了训练的稳定性提升。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验表明,vOPD在多个数学和科学推理基准测试中表现优异。相比于传统的OPD,vOPD在训练稳定性上有显著提升,且在计算开销几乎不变的情况下,性能达到了计算全词表KL散度的昂贵基线水平。Top-k近似策略进一步验证了其在资源受限环境下的高效性与鲁棒性。

🎯 应用场景

该方法主要应用于大语言模型的后训练阶段,特别是在数学、代码生成及科学推理等对逻辑严密性要求极高的领域。通过提升蒸馏过程的稳定性,vOPD能够帮助开发者更高效地将大型教师模型的推理能力迁移至轻量级学生模型,降低推理成本并加速模型部署。

📄 摘要(原文)

On-Policy Distillation (OPD) has emerged as a dominant post-training paradigm for large language models, especially for reasoning domains. However, OPD remains unstable in practice due to the high gradient variance of its single-sample Monte Carlo estimator, and recipes for stable training are still immature. We propose vOPD (On-Policy Distillation with a control variate baseline), which casts OPD as policy-gradient RL and stabilizes it by introducing a control variate baseline-canonically a value function -- from the RL literature. We show that the OPD value function admits a closed form as the per-token negative reverse KL divergence between the student and the teacher, available directly from the already-computed forward pass with no additional critic or inference. Existing stabilization methods either compute the full token-level reverse KL over the entire vocabulary, adding significant overhead, or restrict it to a top-k support, biasing the objective. vOPD instead preserves the lightweight single-sample estimator, subtracting the value function as a detached baseline to keep the gradient unbiased while reducing variance. Furthermore, we show that a top-k approximation of the baseline further lowers cost without compromising performance. Across mathematical and scientific reasoning benchmarks, vOPD consistently outperforms vanilla OPD and matches the most expensive full-vocabulary baseline, offering an efficient stabilization of On-Policy Distillation through principled RL variance reduction.