FP8-RL: A Practical and Stable Low-Precision Stack for LLM Reinforcement Learning
作者: Zhaopeng Qiu, Shuang Yu, Jingqi Zhang, Shuai Zhang, Xue Huang, Jingyi Yang, Junjie Lai
分类: cs.LG, cs.CL
发布日期: 2026-01-26
💡 一句话要点
提出FP8-RL,通过低精度推理加速LLM强化学习并保持训练稳定性。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: FP8量化 强化学习 大型语言模型 Rollout加速 重要性采样
📋 核心要点
- LLM强化学习受限于rollout阶段的计算和内存瓶颈,尤其是在处理长序列时,attention机制和KV-cache消耗大量资源。
- 论文提出FP8-RL方案,通过FP8量化降低计算和内存需求,并采用重要性采样校正rollout偏差,保证训练稳定性。
- 实验结果表明,该方案在稠密和MoE模型上实现了高达44%的rollout吞吐量提升,同时保持了与BF16基线相当的学习性能。
📝 摘要(中文)
针对大型语言模型(LLM)强化学习中rollout阶段因长序列导致attention和KV-cache内存成为瓶颈的问题,本文提出了一种实用的FP8 rollout方案。该方案通过降低计算成本和内存流量来加速RL。然而,在RL中应用FP8面临独特的挑战:策略权重每步都在变化,需要重复量化和同步权重到推理引擎;低精度rollout可能偏离训练器所假定的高精度策略,导致训练-推理不匹配和潜在的不稳定性。本文在veRL生态系统中实现了该方案,支持常见的训练后端(如FSDP/Megatron-LM)和推理引擎(如vLLM/SGLang)。具体而言,我们(i)通过分块FP8量化实现FP8 W8A8线性层rollout,(ii)通过每步QKV尺度重新校准将FP8扩展到KV-cache,消除长上下文内存瓶颈,以及(iii)使用基于重要性采样的rollout校正(token-level TIS/MIS变体)来缓解不匹配问题。在稠密和MoE模型上,这些技术在保持与BF16基线相当的学习行为的同时,提供了高达44%的rollout吞吐量增益。
🔬 方法详解
问题定义:大型语言模型(LLM)的强化学习(RL)过程中,rollout阶段,特别是生成长序列时,计算量和内存需求巨大。Attention机制和KV-cache成为性能瓶颈。现有的高精度(如BF16)计算方式虽然保证了训练的稳定性,但计算成本高昂,限制了rollout的效率。
核心思路:论文的核心思路是利用FP8低精度计算来降低rollout阶段的计算和内存需求,从而加速LLM的强化学习过程。为了解决低精度计算带来的训练-推理不匹配问题,引入了基于重要性采样的rollout校正方法。
技术框架:整体框架基于veRL生态系统,支持常见的训练后端(如FSDP/Megatron-LM)和推理引擎(如vLLM/SGLang)。主要包含三个模块:(1) FP8 W8A8线性层rollout,使用分块FP8量化;(2) FP8 KV-cache,通过每步QKV尺度重新校准消除长上下文内存瓶颈;(3) 基于重要性采样的rollout校正(token-level TIS/MIS变体),缓解训练-推理不匹配。
关键创新:最重要的技术创新点在于将FP8低精度计算应用于LLM的强化学习rollout阶段,并提出了一套完整的解决方案来解决由此带来的训练稳定性和训练-推理一致性问题。与现有方法相比,该方法能够在保证学习性能的同时,显著提升rollout的吞吐量。
关键设计:(1) 分块FP8量化:将权重和激活值分成块,分别进行量化,以减少量化误差。(2) 每步QKV尺度重新校准:动态调整KV-cache的量化尺度,以适应不同上下文的信息。(3) Token-level TIS/MIS:在token级别进行重要性采样,更精细地校正rollout偏差。具体参数设置和损失函数细节未知。
📊 实验亮点
实验结果表明,在稠密和MoE模型上,FP8-RL方案在保持与BF16基线相当的学习行为的同时,实现了高达44%的rollout吞吐量增益。这表明该方案在加速LLM强化学习方面具有显著的优势。
🎯 应用场景
该研究成果可广泛应用于需要快速rollout的大型语言模型强化学习任务中,例如对话系统、文本生成、智能助手等。通过降低计算成本和内存需求,可以支持更大规模的模型训练和更复杂的任务,加速LLM在实际应用中的部署和迭代。
📄 摘要(原文)
Reinforcement learning (RL) for large language models (LLMs) is increasingly bottlenecked by rollout (generation), where long output sequence lengths make attention and KV-cache memory dominate end-to-end step time. FP8 offers an attractive lever for accelerating RL by reducing compute cost and memory traffic during rollout, but applying FP8 in RL introduces unique engineering and algorithmic challenges: policy weights change every step (requiring repeated quantization and weight synchronization into the inference engine) and low-precision rollouts can deviate from the higher-precision policy assumed by the trainer, causing train-inference mismatch and potential instability. This report presents a practical FP8 rollout stack for LLM RL, implemented in the veRL ecosystem with support for common training backends (e.g., FSDP/Megatron-LM) and inference engines (e.g., vLLM/SGLang). We (i) enable FP8 W8A8 linear-layer rollout using blockwise FP8 quantization, (ii) extend FP8 to KV-cache to remove long-context memory bottlenecks via per-step QKV scale recalibration, and (iii) mitigate mismatch using importance-sampling-based rollout correction (token-level TIS/MIS variants). Across dense and MoE models, these techniques deliver up to 44% rollout throughput gains while preserving learning behavior comparable to BF16 baselines.