Training Large Language Models to Reason via EM Policy Gradient
作者: Tianbing Xu
分类: cs.LG, cs.AI, stat.ML
发布日期: 2025-04-24
💡 一句话要点
提出EM策略梯度算法,提升LLM在复杂推理任务中的性能与可解释性
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大型语言模型 强化学习 策略梯度 期望最大化 推理能力 离线学习 认知行为
📋 核心要点
- 现有强化学习方法(如PPO、GRPO)在训练LLM推理能力时,依赖复杂的重要性权重和启发式裁剪,限制了其可扩展性和简洁性。
- EM策略梯度算法将推理任务建模为EM优化问题,通过交替采样推理轨迹和奖励引导微调,简化了训练过程,提升了推理性能。
- 实验表明,EM策略梯度在GSM8K和MATH数据集上达到或超过了SOTA水平,并使模型展现出子问题分解、自我验证等认知行为。
📝 摘要(中文)
本文提出了一种名为EM策略梯度(EM Policy Gradient)的离线强化学习算法,旨在通过优化推理轨迹上的期望回报来增强大型语言模型(LLM)的推理能力。该方法将推理任务建模为期望最大化(EM)优化问题,交替进行多样化推理轨迹的采样和奖励引导的微调。与依赖复杂重要性权重和启发式裁剪的PPO和GRPO不同,EM策略梯度提供了一种更简单、更具原则性的离线策略梯度方法,消除了这些复杂性,同时保持了强大的性能。在GSM8K和MATH(HARD)数据集上的评估表明,EM策略梯度实现了与最先进的GRPO相当或略微超过的性能,同时在可扩展性、简洁性和推理简洁性方面具有额外的优势。此外,使用该方法微调的模型表现出认知行为,例如子问题分解、自我验证和回溯,突显了其增强LLM推理的可解释性和鲁棒性的潜力。
🔬 方法详解
问题定义:现有方法如PPO和GRPO在利用强化学习训练LLM进行推理时,需要处理复杂的策略梯度估计问题,具体来说,它们依赖于重要性采样和启发式裁剪来稳定训练过程,这增加了算法的复杂性,并可能限制其可扩展性和训练效果。此外,这些方法在推理轨迹的探索方面可能存在不足,难以发现更优的推理路径。
核心思路:本文的核心思路是将LLM的推理过程建模为一个期望最大化(EM)问题。E步骤负责生成多样化的推理轨迹,M步骤则利用这些轨迹上的奖励信号来微调LLM的策略,从而优化期望回报。通过交替执行E步骤和M步骤,算法能够逐步提升LLM的推理能力。这种方法避免了直接计算复杂的重要性权重,简化了训练过程。
技术框架:EM策略梯度的整体框架包含两个主要阶段:E步骤(Expectation)和M步骤(Maximization)。在E步骤中,利用当前的LLM策略生成多个推理轨迹,并根据环境反馈(例如,答案是否正确)为每个轨迹分配奖励。在M步骤中,利用收集到的推理轨迹和对应的奖励,通过策略梯度方法微调LLM的参数,目标是最大化期望回报。这两个步骤交替进行,直到LLM的推理能力达到预定的目标。
关键创新:EM策略梯度算法的关键创新在于将强化学习问题转化为EM优化问题,从而避免了直接计算复杂的重要性权重。与传统的策略梯度方法相比,EM策略梯度提供了一种更简单、更具原则性的离线策略梯度方法,降低了算法的复杂性,并提高了训练的稳定性。此外,该方法能够更好地探索推理轨迹空间,发现更优的推理路径。
关键设计:EM策略梯度算法的关键设计包括:1) 推理轨迹的采样策略:采用多样化的采样策略,鼓励模型探索不同的推理路径。2) 奖励函数的设计:根据任务的特点设计合适的奖励函数,引导模型学习正确的推理步骤。3) 策略梯度更新方法:采用合适的策略梯度更新方法,例如Adam优化器,来微调LLM的参数。4) EM迭代的停止条件:设置合适的停止条件,例如达到预定的训练轮数或性能指标。
🖼️ 关键图片
📊 实验亮点
在GSM8K和MATH(HARD)数据集上的实验结果表明,EM策略梯度算法能够达到与最先进的GRPO算法相当或略微超过的性能。例如,在MATH(HARD)数据集上,使用EM策略梯度微调的LLM取得了显著的性能提升,并且展现出子问题分解、自我验证和回溯等认知行为,表明该方法能够有效提升LLM的推理能力和可解释性。
🎯 应用场景
EM策略梯度算法可广泛应用于需要复杂推理能力的场景,如数学问题求解、代码生成、科学研究、智能代理和虚拟助手等。通过提升LLM的推理能力和可解释性,该方法有望提高这些应用领域的性能和可靠性,并促进人工智能技术的进一步发展。
📄 摘要(原文)
Recently, foundation models such as OpenAI's O1 and O3, along with DeepSeek's R1, have demonstrated strong reasoning capacities and problem-solving skills acquired through large-scale reinforcement learning (RL), with wide applications in mathematics, coding, science, intelligent agents, and virtual assistants. In this work, we introduce an off-policy reinforcement learning algorithm, EM Policy Gradient, aimed at enhancing LLM reasoning by optimizing expected return over reasoning trajectories. We frame the reasoning task as an Expectation-Maximization (EM) optimization problem, alternating between sampling diverse rationale trajectories and performing reward-guided fine-tuning. Unlike PPO and GRPO, which rely on complex importance weights and heuristic clipping, our method provides a simpler, more principled off-policy policy gradient approach, eliminating these complexities while maintaining strong performance. We evaluate the effectiveness of EM Policy Gradient on the GSM8K and MATH (HARD) datasets, where it achieves performance comparable to or slightly surpassing the state-of-the-art GRPO, while offering additional advantages in scalability, simplicity, and reasoning conciseness. Moreover, models fine-tuned with our method exhibit cognitive behaviors, such as sub-problem decomposition, self-verification, and backtracking, highlighting its potential to enhance both the interpretability and robustness of LLM reasoning.