RLAX: Large-Scale, Distributed Reinforcement Learning for Large Language Models on TPUs

📄 arXiv: 2512.06392v2 📥 PDF

作者: Runlong Zhou, Lefan Zhang, Shang-Chen Wu, Kelvin Zou, Hanzhi Zhou, Ke Ye, Yihao Feng, Dong Yin, Alex Guillen Garcia, Dmytro Babych, Rohit Chatterjee, Matthew Hopkins, Xiang Kong, Chang Lan, Lezhi Li, Yiping Ma, Daniele Molinari, Senyu Tong, Yanchao Sun, Thomas Voice, Jianyu Wang, Chong Wang, Simon Wang, Floris Weers, Yechen Xu, Guolin Yin, Muyang Yu, Yi Zhang, Zheng Zhou, Danyang Zhuo, Ruoming Pang, Cheng Leong

分类: cs.LG, cs.AI

发布日期: 2025-12-06 (更新: 2025-12-11)

备注: The submission is being withdrawn because internal stakeholders determined that it is not appropriate to publish work on this topic at this time


💡 一句话要点

RLAX:用于大规模语言模型在TPU上的大规模分布式强化学习框架

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 强化学习 大型语言模型 分布式训练 TPU 参数服务器 推理能力 系统优化

📋 核心要点

  1. 现有大型语言模型推理能力提升面临扩展性和训练效率的挑战,尤其是在资源受限的环境下。
  2. RLAX通过参数服务器架构和一系列系统优化技术,实现了在TPU上可扩展且可抢占的强化学习训练。
  3. 实验表明,RLAX显著提升了大型语言模型的推理准确率,并在训练过程中表现出良好的鲁棒性。

📝 摘要(中文)

本文介绍RLAX,一个在TPU上可扩展的强化学习(RL)框架,旨在提升大型语言模型(LLM)的推理能力。RLAX采用参数服务器架构,主训练器定期将更新后的模型权重推送到参数服务器,而大量推理工作器拉取最新的权重并生成新的rollout。我们引入了一系列系统技术,以实现可扩展且可抢占的RL,适用于各种最先进的RL算法。为了加速收敛并提高模型质量,我们设计了新的数据集管理和对齐技术。大规模评估表明,RLAX在1024个v5p TPU上仅用12小时48分钟就将QwQ-32B的pass@8准确率提高了12.8%,同时在训练期间保持了对抢占的鲁棒性。

🔬 方法详解

问题定义:现有方法在利用强化学习提升大型语言模型推理能力时,面临着训练规模大、计算资源需求高、训练过程易受中断影响等问题。尤其是在TPU等专用硬件上进行大规模分布式训练时,如何高效地利用资源、加速收敛、保证训练的稳定性是一个巨大的挑战。

核心思路:RLAX的核心思路是采用参数服务器架构,将模型参数存储在中心化的参数服务器上,训练器定期推送更新,推理工作器拉取最新参数。这种架构能够有效地解耦训练和推理过程,提高训练效率和可扩展性。同时,通过一系列系统优化技术,如数据并行、模型并行、流水线并行等,进一步提升训练速度和资源利用率。

技术框架:RLAX的整体架构包括以下几个主要模块:1) 主训练器:负责模型的训练和参数更新;2) 参数服务器:存储模型的参数,并提供参数的推送和拉取服务;3) 推理工作器:负责生成新的rollout,用于训练;4) 数据集管理模块:负责数据集的构建、清洗和对齐。训练流程大致如下:主训练器从参数服务器拉取最新的模型参数,利用rollout数据进行训练,然后将更新后的参数推送回参数服务器。推理工作器定期从参数服务器拉取最新的模型参数,生成新的rollout数据。

关键创新:RLAX的关键创新在于其系统层面的优化,包括:1) 可扩展的参数服务器架构,能够支持大规模分布式训练;2) 针对TPU的优化,充分利用TPU的计算能力;3) 可抢占的训练机制,保证训练的鲁棒性;4) 新的数据集管理和对齐技术,加速收敛并提高模型质量。

关键设计:RLAX的关键设计包括:1) 参数服务器的存储和通信机制,需要保证高效的参数推送和拉取;2) 损失函数的设计,需要能够有效地指导模型的训练;3) 数据集的构建和清洗,需要保证数据的质量和多样性;4) 训练过程中的超参数调整,需要根据具体的任务进行优化。

🖼️ 关键图片

img_0

📊 实验亮点

RLAX在1024个v5p TPU上仅用12小时48分钟就将QwQ-32B的pass@8准确率提高了12.8%,证明了其在大规模分布式强化学习方面的有效性。此外,RLAX在训练过程中表现出良好的鲁棒性,能够抵抗抢占的影响,保证训练的顺利进行。

🎯 应用场景

RLAX框架可广泛应用于各种需要提升推理能力的大型语言模型,例如代码生成、数学问题求解、文本摘要等。该框架能够加速模型训练,提高模型性能,并降低训练成本。未来,RLAX有望成为构建更强大、更智能的AI系统的关键基础设施。

📄 摘要(原文)

Reinforcement learning (RL) has emerged as the de-facto paradigm for improving the reasoning capabilities of large language models (LLMs). We have developed RLAX, a scalable RL framework on TPUs. RLAX employs a parameter-server architecture. A master trainer periodically pushes updated model weights to the parameter server while a fleet of inference workers pull the latest weights and generates new rollouts. We introduce a suite of system techniques to enable scalable and preemptible RL for a diverse set of state-of-art RL algorithms. To accelerate convergence and improve model quality, we have devised new dataset curation and alignment techniques. Large-scale evaluations show that RLAX improves QwQ-32B's pass@8 accuracy by 12.8% in just 12 hours 48 minutes on 1024 v5p TPUs, while remaining robust to preemptions during training.