How Reinforcement Learning After Next-Token Prediction Facilitates Learning

📄 arXiv: 2510.11495v2 📥 PDF

作者: Nikolaos Tsilivis, Eran Malach, Karen Ullrich, Julia Kempe

分类: cs.LG, stat.ML

发布日期: 2025-10-13 (更新: 2025-12-16)


💡 一句话要点

提出强化学习后接续预测框架,提升LLM在推理任务中的泛化能力

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 强化学习 语言模型 接续预测 思维链 泛化能力 推理任务 自回归模型

📋 核心要点

  1. 现有方法在推理领域依赖大规模语言模型,但单纯的接续预测在复杂任务上泛化能力不足。
  2. 论文提出强化学习后接续预测的训练框架,利用强化学习优化语言模型,提升其在推理任务中的性能。
  3. 实验证明,该方法在奇偶校验预测和数学推理基准测试中,显著提升了模型的泛化能力和推理效果。

📝 摘要(中文)

本文提出了一个框架,用于研究大型语言模型(LLM)在序列预测后通过强化学习算法进行优化这一范式的成功之处。从理论上揭示了强化学习在这种设置下优于接续预测的优化机制。研究了由短链和长链“思维链”序列混合分布中学习单个任务的情况。特别地,当任务是预测d位奇偶校验且长序列稀少时,表明强化学习在接续预测之后能够使自回归Transformer泛化,而单纯的接续预测需要极端的统计或计算资源才能做到这一点。进一步解释了强化学习如何利用增加的测试时计算(表现为更长的响应)来促进这一学习过程。在一个简化的设置中,从理论上证明,只要数据混合中长演示的比例不以输入维度d呈指数级地小,遵循这种训练方案的自回归线性模型就可以有效地学习预测d位奇偶校验。最后,在其他设置中也展示了相同的现象,包括在常见数学推理基准的混合变体上对Llama系列模型进行后训练。

🔬 方法详解

问题定义:论文旨在解决大型语言模型在复杂推理任务中,仅通过接续预测训练难以泛化的问题。现有方法,即单纯的接续预测,在长序列和稀疏奖励的情况下,需要大量的计算资源和数据才能达到较好的性能,且容易过拟合短序列。

核心思路:论文的核心思路是利用强化学习(RL)在接续预测(Next-Token Prediction, NTP)之后对模型进行微调。通过RL,模型可以更好地探索奖励空间,学习到更有效的策略,从而提高在推理任务中的泛化能力。这种方法允许模型在测试时进行更长的推理链,从而更好地利用计算资源。

技术框架:整体框架包含两个主要阶段:1) 预训练阶段:使用标准的接续预测目标训练大型语言模型。2) 强化学习微调阶段:使用RL算法(如策略梯度)对预训练模型进行微调,目标是最大化任务奖励。在推理阶段,模型可以生成更长的序列,从而进行更复杂的推理。

关键创新:论文的关键创新在于将强化学习与接续预测相结合,提出了一种有效的训练范式。这种方法能够克服单纯接续预测的局限性,使模型能够更好地利用测试时的计算资源,从而提高泛化能力。此外,论文还从理论上分析了强化学习在这种训练范式下的优化机制。

关键设计:论文中,关键的设计包括:1) 混合数据集:使用包含短链和长链“思维链”序列的混合数据集进行训练。2) 奖励函数:根据任务目标设计合适的奖励函数,例如,在奇偶校验预测任务中,只有预测正确时才给予奖励。3) 模型结构:使用自回归Transformer模型作为基础模型。4) 强化学习算法:可以使用各种策略梯度算法,如REINFORCE或PPO。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

论文通过实验证明,在奇偶校验预测任务中,强化学习后接续预测的方法能够显著提升模型的泛化能力,尤其是在长序列稀少的情况下。此外,在Llama系列模型上进行的数学推理实验也表明,该方法能够有效提升模型在常见数学推理基准上的性能。

🎯 应用场景

该研究成果可应用于各种需要复杂推理能力的场景,例如数学问题求解、代码生成、对话系统等。通过强化学习的微调,可以显著提升语言模型在这些领域的性能,使其能够更好地理解和解决复杂问题。未来,该方法有望推动人工智能在更广泛领域的应用。

📄 摘要(原文)

Recent advances in reasoning domains with neural networks have primarily been enabled by a training recipe that optimizes Large Language Models, previously trained to predict the next-token in a sequence, with reinforcement learning algorithms. We introduce a framework to study the success of this paradigm, and we theoretically expose the optimization mechanisms by which reinforcement learning improves over next-token prediction in this setting. We study learning from mixture distributions of short and long ``chain-of-thought'' sequences encoding a single task. In particular, when the task consists of predicting the parity of $d$ bits and long sequences are rare, we show how reinforcement learning after next-token prediction enables autoregressive transformers to generalize, whereas mere next-token prediction requires extreme statistical or computational resources to do so. We further explain how reinforcement learning leverages increased test-time computation, manifested in longer responses, to facilitate this learning process. In a simplified setting, we theoretically prove that autoregressive linear models following this training recipe can efficiently learn to predict the parity of $d$ bits as long as the proportion of long demonstrations in the data mix is not exponentially small in the input dimension $d$. Finally, we demonstrate these same phenomena in other settings, including the post-training of Llama-series models on mixture variations of common mathematical reasoning benchmarks.