How to Compress KV Cache in RL Post-Training? Shadow Mask Distillation for Memory-Efficient Alignment
作者: Rui Zhu, Weiheng Bai, Qiushi Wu, Yang Ren, Haixu Tang, Yuchu Liu
分类: cs.LG, cs.AI
发布日期: 2026-05-07
💡 一句话要点
提出影子掩码蒸馏(Shadow Mask Distillation)方法,解决强化学习后训练中KV缓存压缩导致的策略偏差问题。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大语言模型 强化学习 KV缓存压缩 模型蒸馏 长上下文推理 策略优化
📋 核心要点
- 核心问题:在线RL训练中,KV缓存压缩引入的微小近似误差在策略优化过程中被指数级放大,导致严重的离策略偏差与训练不稳定。
- 方法要点:提出影子掩码蒸馏(Shadow Mask Distillation),通过在训练过程中对齐压缩模型与全精度模型的行为,消除压缩带来的分布偏移。
- 实验或效果:该方法有效缓解了长上下文任务中的内存瓶颈,在保持训练稳定性的同时,显著提升了RL后训练的样本效率与最终模型性能。
📝 摘要(中文)
强化学习(RL)已成为解锁大语言模型(LLM)高级推理能力的关键范式,涵盖RLHF和RLAIF等框架。无论采用何种优化算法(如PPO、GRPO或在线DPO),在线RL本质上都需要探索性的轨迹生成(rollout)阶段。然而,对于长上下文推理任务,该阶段因巨大的键值(KV)缓存占用而面临严重的“内存墙”问题。虽然在rollout期间应用KV缓存压缩可以缓解内存开销,但它会引入严重的离策略(off-policy)偏差。尽管现代KV压缩在标准推理中几乎无损,但RL优化固有的不稳定性会放大微小的近似误差。具体而言,采样器在稀疏上下文下生成响应,而学习器则使用完整的稠密上下文更新参数。现有的统计解决方案(如重要性采样重加权)难以纠正这种放大的偏差,且存在梯度方差大和样本效率低的问题。
🔬 方法详解
问题定义:论文旨在解决长上下文LLM在在线RL训练中,因KV缓存压缩导致的“rollout阶段稀疏上下文”与“训练阶段稠密上下文”之间的分布不一致问题,这种偏差会严重破坏策略梯度的估计。
核心思路:论文提出影子掩码蒸馏(Shadow Mask Distillation),其核心思想是将压缩后的模型视为“影子”,通过蒸馏机制强制其输出分布与全精度模型在相同上下文下的行为保持一致,从而在压缩内存的同时消除偏差。
技术框架:整体框架包含两个并行路径:一个是使用压缩KV缓存的采样器(Shadow),另一个是使用全精度KV缓存的参考模型(Teacher)。通过计算两者在轨迹生成过程中的KL散度或掩码分布差异,将蒸馏损失引入到RL的优化目标中。
关键创新:最重要的创新在于将KV缓存的压缩决策(Mask)视为可学习的蒸馏目标,而非简单的启发式剪枝。这种方法使得压缩模型能够学习到在RL探索中对决策至关重要的关键信息,从而在低内存占用下维持高保真度。
关键设计:关键设计包括引入影子掩码损失函数,该函数动态惩罚压缩模型与全精度模型在注意力权重分布上的差异。此外,通过引入自适应的掩码阈值调整机制,确保在不同推理阶段平衡内存压缩率与策略一致性。
🖼️ 关键图片
📊 实验亮点
实验表明,影子掩码蒸馏在长上下文任务中实现了显著的内存节省(最高可达数倍),同时在保持训练稳定性的前提下,相比传统KV压缩方法,在推理准确率和奖励得分上均有显著提升,有效解决了长序列RL训练中的样本效率低下问题。
🎯 应用场景
该研究主要应用于长文本大模型的强化学习后训练阶段,特别是在处理超长上下文推理、复杂逻辑链(CoT)生成等对内存要求极高的场景。它为在消费级硬件或受限计算资源下进行大规模模型对齐提供了技术路径,具有极高的工业落地价值。
📄 摘要(原文)
Reinforcement Learning (RL) has emerged as a crucial paradigm for unlocking the advanced reasoning capabilities of Large Language Models (LLMs), encompassing frameworks like RLHF and RLAIF. Regardless of the specific optimization algorithm (e.g., PPO, GRPO, or Online DPO), online RL inherently requires an exploratory trajectory generation (rollout) phase. However, for long-context reasoning tasks, this rollout phase imposes a severe ``memory wall'' due to the exorbitant Key-Value (KV) cache footprint. While applying KV cache compression during rollouts mitigates this memory overhead, it induces a critical off-policy bias. Although modern KV compression is often nearly lossless during standard inference, even minuscule approximation errors are drastically amplified by the inherent instability of RL optimization. Specifically, the sampler generates responses under a sparse context, whereas the learner updates parameters using the full, dense context. Existing statistical solutions, such as importance reweighting, struggle to correct this magnified bias, suffering from high gradient variance and severe sample inefficiency.