Accelerating RL for LLM Reasoning with Optimal Advantage Regression

📄 arXiv: 2505.20686v1 📥 PDF

作者: Kianté Brantley, Mingyu Chen, Zhaolin Gao, Jason D. Lee, Wen Sun, Wenhao Zhan, Xuezhou Zhang

分类: cs.LG, cs.AI

发布日期: 2025-05-27

🔗 代码/项目: GITHUB


💡 一句话要点

提出A*-PO算法,通过最优优势回归加速LLM推理的强化学习训练。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 强化学习 大型语言模型 推理能力 策略优化 优势函数

📋 核心要点

  1. 现有强化学习方法微调LLM推理能力时,计算开销大、内存消耗高,主要因为需要多次生成和依赖复杂的价值估计。
  2. A*-PO通过两阶段优化,首先离线估计最优价值函数,然后在线进行策略更新,直接近似最优优势函数。
  3. 实验表明,A*-PO在数学推理任务上表现出色,训练时间缩短至2倍,内存占用减少30%,优于PPO等方法。

📝 摘要(中文)

强化学习(RL)已成为微调大型语言模型(LLM)以提升复杂推理能力的强大工具。然而,当前最优策略优化方法通常面临高计算开销和内存消耗的问题,这主要是由于每个提示需要多次生成以及依赖于评论家网络或当前策略的优势估计。本文提出了A-PO,一种新颖的两阶段策略优化框架,它直接近似最优优势函数,从而能够高效地训练LLM进行推理任务。在第一阶段,我们利用来自参考策略的离线采样来估计最优价值函数V,消除了对昂贵的在线价值估计的需求。在第二阶段,我们使用简单的最小二乘回归损失执行在线策略更新,每个提示仅需一次生成。理论上,我们建立了性能保证,并证明了KL正则化的RL目标可以在不需要复杂探索策略的情况下进行优化。实验表明,A-PO在各种数学推理基准测试中取得了具有竞争力的性能,同时与PPO、GRPO和REBEL相比,训练时间最多减少2倍,峰值内存使用量减少30%以上。A-PO的实现可在https://github.com/ZhaolinGao/A-PO 找到。

🔬 方法详解

问题定义:现有基于强化学习的LLM推理能力提升方法,如PPO,存在计算开销大和内存消耗高的问题。这些方法通常需要对每个prompt进行多次生成,并依赖于复杂的评论家网络或优势函数估计,导致训练效率低下。

核心思路:A*-PO的核心思路是通过直接近似最优优势函数来加速强化学习训练过程。它避免了传统方法中对当前策略进行多次采样和价值估计的需要,转而利用离线数据估计最优价值函数,从而简化了在线策略更新过程。

技术框架:A-PO包含两个主要阶段:1) 离线价值估计阶段:利用参考策略的离线采样数据,通过回归方法估计最优价值函数V。这一阶段消除了在线价值估计的需要。2) 在线策略更新阶段:使用简单的最小二乘回归损失,基于V*进行on-policy更新。每个prompt只需要一次生成,显著降低了计算成本。

关键创新:A-PO的关键创新在于直接近似最优优势函数,避免了对当前策略的复杂价值估计。通过离线估计最优价值函数,A-PO能够更有效地进行策略优化,并减少了对复杂探索策略的依赖。

关键设计:A-PO的关键设计包括:1) 使用离线数据估计最优价值函数V。2) 使用最小二乘回归损失进行策略更新。3) 采用KL散度正则化,以保证策略更新的稳定性。具体的网络结构和参数设置取决于具体的LLM和推理任务。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

A-PO在多个数学推理基准测试中表现出竞争力,与PPO、GRPO和REBEL等基线方法相比,训练时间最多减少2倍,峰值内存使用量减少30%以上。这些结果表明A-PO在加速LLM推理能力训练方面具有显著优势。

🎯 应用场景

A*-PO算法可应用于各种需要复杂推理能力的大型语言模型,例如数学问题求解、代码生成、逻辑推理等。该方法能够显著降低训练成本,提高模型性能,加速LLM在实际场景中的部署和应用,具有广泛的应用前景。

📄 摘要(原文)

Reinforcement learning (RL) has emerged as a powerful tool for fine-tuning large language models (LLMs) to improve complex reasoning abilities. However, state-of-the-art policy optimization methods often suffer from high computational overhead and memory consumption, primarily due to the need for multiple generations per prompt and the reliance on critic networks or advantage estimates of the current policy. In this paper, we propose $A$-PO, a novel two-stage policy optimization framework that directly approximates the optimal advantage function and enables efficient training of LLMs for reasoning tasks. In the first stage, we leverage offline sampling from a reference policy to estimate the optimal value function $V$, eliminating the need for costly online value estimation. In the second stage, we perform on-policy updates using a simple least-squares regression loss with only a single generation per prompt. Theoretically, we establish performance guarantees and prove that the KL-regularized RL objective can be optimized without requiring complex exploration strategies. Empirically, $A$-PO achieves competitive performance across a wide range of mathematical reasoning benchmarks, while reducing training time by up to 2$\times$ and peak memory usage by over 30% compared to PPO, GRPO, and REBEL. Implementation of $A$-PO can be found at https://github.com/ZhaolinGao/A-PO.