Taming the Tail: Stable LLM Reinforcement Learning via Dynamic Vocabulary Pruning
作者: Yingru Li, Jiawei Xu, Jiacai Liu, Yuxuan Tong, Ziniu Li, Tianle Cai, Ge Zhang, Qian Liu, Baoxiang Wang
分类: cs.LG, cs.AI, stat.ML
发布日期: 2025-12-28
💡 一句话要点
提出基于动态词汇剪枝的稳定LLM强化学习方法,解决训练-推理不匹配问题。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 强化学习 大型语言模型 词汇剪枝 训练稳定性 概率分布匹配
📋 核心要点
- 现有LLM强化学习方法因训练和推理阶段概率分布差异,导致训练不稳定,梯度估计偏差大。
- 提出动态词汇剪枝策略,约束强化学习目标在“安全”词汇表内,排除低概率token,降低训练-推理不匹配。
- 实验结果表明,该方法能有效稳定训练过程,理论分析也证明了词汇剪枝引入的优化偏差可控。
📝 摘要(中文)
大型语言模型(LLM)的强化学习面临一个根本性矛盾:高吞吐量推理引擎和数值精确的训练系统从相同的参数产生不同的概率分布,从而造成训练-推理不匹配。我们证明了这种不匹配具有不对称效应:log-概率不匹配的界限与(1-p)成比例,其中p是token概率。对于高概率token,这个界限消失,对序列级别的不匹配贡献很小。对于尾部低概率token,这个界限仍然很大,并且当被采样时,这些token表现出系统性偏差的不匹配,这些不匹配会在序列上累积,从而破坏梯度估计的稳定性。我们没有应用事后校正,而是建议将RL目标约束到一个动态剪枝的“安全”词汇表,该词汇表排除了极端的尾部。通过剪枝这些token,我们用小的、有界的优化偏差来换取大的、系统性偏差的不匹配。在经验上,我们的方法实现了稳定的训练;在理论上,我们限制了词汇剪枝引入的优化偏差。
🔬 方法详解
问题定义:现有LLM强化学习方法在训练过程中,由于高吞吐量推理引擎和数值精确的训练系统产生不同的概率分布,导致训练-推理不匹配。这种不匹配尤其体现在低概率token上,它们的不匹配偏差会累积,严重影响梯度估计的准确性,最终导致训练不稳定。现有方法缺乏对这种不匹配的有效控制。
核心思路:论文的核心思路是通过动态剪枝词汇表,移除那些低概率的token,从而限制强化学习的目标空间。这样做的目的是减少训练和推理阶段概率分布的差异,特别是避免低概率token带来的系统性偏差。通过牺牲少量优化偏差,换取训练过程的稳定性。
技术框架:该方法的核心在于动态词汇剪枝。具体来说,在强化学习训练的每个迭代步骤中,首先根据当前模型的概率分布,动态地确定一个“安全”词汇表,该词汇表排除了概率低于某个阈值的token。然后,强化学习的目标函数被限制在这个“安全”词汇表内,只对这些token进行优化。这样可以避免低概率token带来的梯度偏差,从而稳定训练过程。
关键创新:该方法最重要的创新点在于动态词汇剪枝策略。与传统的静态词汇表相比,动态剪枝能够根据模型的学习状态自适应地调整词汇表,从而更好地平衡训练的稳定性和优化效果。此外,该方法还提供了理论分析,证明了词汇剪枝引入的优化偏差是可控的。
关键设计:关键的设计包括:1) 动态词汇表的确定方法,需要选择合适的阈值来平衡剪枝带来的偏差和稳定性的提升;2) 强化学习目标的约束方式,需要确保在剪枝后的词汇表内仍然能够有效地进行策略优化;3) 理论分析中,需要对剪枝引入的优化偏差进行精确的量化和界定。
🖼️ 关键图片
📊 实验亮点
该方法通过动态词汇剪枝,有效稳定了LLM的强化学习训练过程。实验结果表明,该方法能够在保证模型性能的同时,显著降低训练过程中的不稳定性,并提供了优化偏差的理论上界。
🎯 应用场景
该研究成果可应用于各种需要利用强化学习微调大型语言模型的场景,例如对话系统、文本生成、代码生成等。通过稳定训练过程,可以提升模型的性能和可靠性,降低训练成本,加速LLM在实际应用中的部署。
📄 摘要(原文)
Reinforcement learning for large language models (LLMs) faces a fundamental tension: high-throughput inference engines and numerically-precise training systems produce different probability distributions from the same parameters, creating a training-inference mismatch. We prove this mismatch has an asymmetric effect: the bound on log-probability mismatch scales as $(1-p)$ where $p$ is the token probability. For high-probability tokens, this bound vanishes, contributing negligibly to sequence-level mismatch. For low-probability tokens in the tail, the bound remains large, and moreover, when sampled, these tokens exhibit systematically biased mismatches that accumulate over sequences, destabilizing gradient estimation. Rather than applying post-hoc corrections, we propose constraining the RL objective to a dynamically-pruned ``safe'' vocabulary that excludes the extreme tail. By pruning such tokens, we trade large, systematically biased mismatches for a small, bounded optimization bias. Empirically, our method achieves stable training; theoretically, we bound the optimization bias introduced by vocabulary pruning.