Learn Where Outcomes Diverge: Efficient VLA RL via Probabilistic Chunk Masking

📄 arXiv: 2605.16154v1 📥 PDF

作者: Vaidehi Bagaria, Nikshep Grampurohit, Pulkit Verma

分类: cs.LG, cs.RO

发布日期: 2026-05-15


💡 一句话要点

提出概率块掩码(PCM)加速VLA强化学习,提升梯度计算效率。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 视觉语言动作 强化学习 梯度计算 概率掩码 计算效率

📋 核心要点

  1. VLA强化学习中,梯度计算是主要瓶颈,现有方法对所有轨迹块均匀计算梯度,效率低下。
  2. 提出概率块掩码(PCM),通过成功-失败动作方差估计梯度方差,有选择地计算梯度。
  3. 实验表明,PCM在保持成功率的同时,显著提升了计算速度和内存效率,加速了VLA强化学习。

📝 摘要(中文)

强化学习(RL)通过直接优化任务成功率,使视觉-语言-动作(VLA)策略能够泛化到训练分布之外,但后训练计算成本高昂。虽然可以通过更快的模拟器和世界模型来加速rollout收集,但在基于GRPO的VLA RL中,梯度计算占据了每步wall-clock时间的约78%,而rollout收集仅占21%。梯度成本占主导地位是因为大部分计算花费在对学习贡献很小的阶段。GRPO的学习信号由优势方差驱动:只有成功和失败rollout发散的阶段才会产生学习信号。然而,GRPO为rollout中的每个块分配相同的优势。因此,actor更新计算均匀地分布在整个轨迹上,包括策略在预训练和监督微调后已经处理好的阶段。本文提出了概率块掩码(PCM),这是一种GRPO的即插即用修改,它将梯度计算分配给每个轨迹中一小部分概率选择的块子集。PCM使用成功-失败动作方差对语义阶段进行评分,这是一种rollout导出的每阶段梯度方差代理,并使用在线更新的阶段级别保持概率来采样固定块预算。我们将每阶段梯度方差形式化为确定梯度计算有用性的量,并表明成功-失败动作方差为其提供了一个可测量的代理。PCM不需要奖励模型或学习到的评论家。在三个LIBERO基准测试中,PCM匹配了标准GRPO的最终成功率,同时实现了2.38倍的wall-clock加速,4.8倍更快的梯度更新和60%更低的峰值激活内存,同时反向传播通过少于20%的轨迹块。

🔬 方法详解

问题定义:论文旨在解决视觉-语言-动作(VLA)强化学习中梯度计算效率低下的问题。现有方法,如GRPO,在更新actor时,对轨迹中的每个chunk都进行梯度计算,而实际上只有成功和失败的rollout出现差异的chunk才对学习有贡献。这导致大量的计算资源被浪费在对学习无益的阶段,成为VLA强化学习的瓶颈。

核心思路:论文的核心思路是只对那些对学习有贡献的轨迹chunk进行梯度计算。通过估计每个chunk的梯度方差,并优先选择梯度方差大的chunk进行计算,从而提高计算效率。具体来说,论文使用成功-失败动作方差作为每阶段梯度方差的代理指标,并根据该指标对chunk进行概率采样。

技术框架:PCM是GRPO的改进版本,整体框架与GRPO类似,主要包括以下几个阶段: 1. Rollout收集:使用当前的策略与环境交互,收集轨迹数据。 2. 优势计算:计算每个chunk的优势函数,用于指导策略更新。 3. 概率块掩码(PCM):根据成功-失败动作方差,计算每个chunk的保持概率,并根据该概率对chunk进行采样。 4. 梯度计算与更新:只对采样到的chunk进行梯度计算,并更新actor网络。

关键创新:论文最重要的创新点在于提出了使用成功-失败动作方差作为每阶段梯度方差的代理指标。与现有方法不同,PCM不是对所有chunk进行均匀的梯度计算,而是根据chunk的重要性进行选择性计算,从而显著提高了计算效率。此外,PCM不需要额外的奖励模型或学习到的评论家,易于集成到现有的VLA强化学习框架中。

关键设计: * 成功-失败动作方差:使用成功和失败rollout中动作的方差来估计每阶段的梯度方差。方差越大,说明该阶段对学习越重要。 * 在线更新保持概率:根据成功-失败动作方差,在线更新每个chunk的保持概率。保持概率高的chunk更容易被采样到。 * 固定块预算:为了控制计算成本,论文设置了一个固定的块预算,即每次只采样固定数量的chunk进行梯度计算。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,PCM在三个LIBERO基准测试中,能够在匹配标准GRPO最终成功率的同时,实现2.38倍的wall-clock加速,4.8倍更快的梯度更新,以及60%更低的峰值激活内存。此外,PCM仅需反向传播通过少于20%的轨迹块,证明了其高效性。

🎯 应用场景

该研究成果可应用于各种需要高效VLA强化学习的场景,例如机器人导航、任务规划、游戏AI等。通过减少梯度计算量,可以加速模型的训练过程,降低计算资源消耗,从而使得VLA强化学习能够应用于更复杂的任务和更广泛的领域。未来,该方法有望推动机器人自主学习和人机协作等领域的发展。

📄 摘要(原文)

Reinforcement learning (RL) allows vision-language-action (VLA) policies to generalize beyond their training distribution by optimizing directly for task success, but post-training is computationally expensive. A natural response has been to speed rollout collection through faster simulators and world models. In GRPO-based VLA RL, we find that the dominant cost lies elsewhere: gradient computation accounts for approximately 78% of wall-clock time per step in our runs, while rollout collection accounts for only 21%. Gradient cost dominates because much of this computation is spent on phases that contribute little to learning. GRPO's learning signal is driven by advantage variance: only phases where successful and failed rollouts diverge produce learning signal. However, GRPO assigns the same advantage to every chunk in a rollout. As a result, actor-update compute is spent uniformly across the trajectory, including phases the policy already handles after pre-training and supervised fine-tuning. This paper presents Probabilistic Chunk Masking (PCM), a drop-in modification to GRPO that allocates gradient computation to a small, probabilistically selected subset of chunks per trajectory. PCM scores semantic phases using success-failure action variance, a rollout-derived proxy for per-phase gradient variance, and samples a fixed chunk budget with online-updated phase-level keep probabilities. We formalize per-phase gradient variance as the quantity determines where gradient computation is useful and show that success-failure action variance provides a measurable proxy for it. PCM requires no reward model or learned critic. On three LIBERO benchmarks, PCM matches the final success rate of standard GRPO while achieving 2.38 times wall-clock speedup, 4.8 times faster gradient updates, and 60% lower peak activation memory, while backpropagating through fewer than 20% of trajectory chunks.