Stabilizing MoE Reinforcement Learning by Aligning Training and Inference Routers
作者: Wenhan Ma, Hailin Zhang, Liang Zhao, Yifan Song, Yudong Wang, Zhifang Sui, Fuli Luo
分类: cs.CL, cs.AI, cs.LG
发布日期: 2025-10-13 (更新: 2025-10-21)
💡 一句话要点
提出R3方法,对齐MoE强化学习训练与推理路由,稳定训练过程。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 混合专家模型 强化学习 路由机制 训练推理一致性 策略优化
📋 核心要点
- MoE模型在强化学习中面临训练不稳定问题,特别是路由机制导致训练与推理阶段行为不一致。
- 论文提出Rollout Routing Replay (R3)方法,通过记录推理时的路由分布并在训练时重放,来对齐训练和推理。
- 实验表明R3能有效稳定RL训练,防止崩溃,并在性能上优于现有方法,同时保持训练速度。
📝 摘要(中文)
强化学习(RL)已成为增强大型语言模型能力的关键方法。然而,在混合专家(MoE)模型中,路由机制常常引入不稳定性,甚至导致灾难性的RL训练崩溃。我们分析了MoE模型的训练-推理一致性,并发现这两个阶段的路由行为存在显著差异。此外,即使在相同的条件下,路由框架也会在重复的前向传递中产生不同的专家选择。为了解决这种根本的不一致性,我们提出Rollout Routing Replay (R3),该方法记录来自推理引擎的路由分布,并在训练期间重放它们。R3显著降低了训练-推理策略的KL散度,并减轻了极端差异,而不会影响训练速度。在各种设置下进行的大量实验证实,R3成功地稳定了RL训练,防止了崩溃,并且优于GSPO和TIS等方法。我们相信这项工作可以为稳定MoE模型中的RL提供一种新的解决方案。
🔬 方法详解
问题定义:MoE模型在强化学习训练中,由于路由机制的不稳定性,容易出现训练崩溃的问题。现有方法未能有效解决训练和推理阶段路由行为的不一致性,导致策略学习不稳定。即使输入相同,MoE的路由选择也可能因随机性而不同,加剧了这一问题。
核心思路:R3的核心思路是通过模仿推理阶段的路由行为来稳定训练过程。具体来说,R3记录推理过程中的路由分布,并在训练过程中重放这些分布,从而使训练过程中的路由选择更接近推理时的路由选择,减少训练和推理之间的差异。
技术框架:R3方法主要包含以下步骤:1) 在推理阶段,记录每个token的路由分布(即每个专家被选择的概率)。2) 在训练阶段,从经验回放缓冲区中采样数据。3) 使用记录的路由分布作为目标,通过最小化训练时的路由分布与目标路由分布之间的差异来更新路由策略。整体框架是在标准的强化学习训练流程中加入路由分布的记录和重放机制。
关键创新:R3的关键创新在于显式地对齐了训练和推理阶段的路由分布。与现有方法不同,R3直接干预路由过程,使其在训练时模仿推理时的行为,从而减少了策略学习的不确定性。这种方法避免了对策略梯度进行复杂的修改,而是从根本上解决了路由不一致的问题。
关键设计:R3的关键设计包括:1) 使用KL散度作为损失函数,衡量训练时路由分布与目标路由分布之间的差异。2) 引入一个超参数来控制路由分布重放的强度,平衡模仿学习和探索之间的关系。3) 在经验回放缓冲区中存储路由分布,以便在训练时进行重放。具体实现上,需要修改MoE模型的路由层,使其能够记录和重放路由分布。
🖼️ 关键图片
📊 实验亮点
实验结果表明,R3方法在多个MoE强化学习任务中成功稳定了训练过程,有效防止了训练崩溃。与GSPO和TIS等基线方法相比,R3在性能上取得了显著提升,并且没有牺牲训练速度。具体性能提升幅度取决于具体的任务和模型设置,但总体趋势是R3能够显著改善MoE强化学习的稳定性和性能。
🎯 应用场景
该研究成果可广泛应用于各种基于MoE的强化学习任务中,尤其是在需要稳定训练过程的复杂环境中。例如,可以应用于大型语言模型的指令微调、对话系统训练、以及机器人控制等领域,提升模型的性能和鲁棒性,降低训练成本。
📄 摘要(原文)
Reinforcement learning (RL) has emerged as a crucial approach for enhancing the capabilities of large language models. However, in Mixture-of-Experts (MoE) models, the routing mechanism often introduces instability, even leading to catastrophic RL training collapse. We analyze the training-inference consistency of MoE models and identify a notable discrepancy in routing behaviors between the two phases. Moreover, even under identical conditions, the routing framework can yield divergent expert selections across repeated forward passes. To address this foundational inconsistency, we propose Rollout Routing Replay (R3), a method that records routing distributions from the inference engine and replays them during training. R3 significantly reduces training-inference policy KL divergence and mitigates extreme discrepancies without compromising training speed. Extensive experiments on various settings confirm that R3 succeeds in stabilizing RL training, preventing collapse and outperforming methods such as GSPO and TIS. We believe this work can offer a new solution for stabilizing RL in MoE models.