Learning Long-Context Diffusion Policies via Past-Token Prediction
作者: Marcel Torne, Andy Tang, Yuejiang Liu, Chelsea Finn
分类: cs.RO, cs.AI, cs.LG
发布日期: 2025-05-14 (更新: 2025-05-19)
备注: Videos are available at https://long-context-dp.github.io
💡 一句话要点
提出Past-Token Prediction,解决长序列机器人任务中Diffusion策略的上下文依赖问题。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 长上下文策略 Diffusion模型 机器人学习 模仿学习 时间建模 Past-Token Prediction 序列预测
📋 核心要点
- 现有长上下文策略学习方法,如截断历史信息,导致策略无法捕捉过去和未来动作间的依赖关系。
- 提出Past-Token Prediction (PTP)辅助任务,通过预测过去动作token来正则化策略,增强时间建模能力。
- 采用多阶段训练策略,预训练视觉编码器,微调策略头,并结合测试时自我验证机制,提升性能。
📝 摘要(中文)
在许多机器人任务中,对长序列的观察和动作进行推理至关重要。然而,从演示中学习有效的长上下文策略仍然具有挑战性。随着上下文长度的增加,由于内存需求的增加,训练变得越来越昂贵,并且策略性能通常会因虚假相关性而降低。最近的方法通常通过截断上下文长度来回避这些问题,从而丢弃了可能对后续决策至关重要的历史信息。本文提出了一种显式地正则化过去信息保留的替代方法。我们首先重新审视模仿学习中的copycat问题,并发现最近的扩散策略中存在相反的挑战:它们常常无法捕捉过去和未来动作之间必不可少的依赖关系,而不是过度依赖先前的动作。为了解决这个问题,我们引入了Past-Token Prediction (PTP),这是一个辅助任务,其中策略学习预测过去的动作token以及未来的动作token。这种正则化显着改善了策略头中的时间建模,而对视觉表示的依赖性最小。基于此观察,我们进一步引入了一种多阶段训练策略:使用短上下文预训练视觉编码器,并使用缓存的长上下文嵌入微调策略头。这种策略保留了PTP的优势,同时大大降低了内存和计算开销。最后,我们将PTP扩展到测试时的自我验证机制,使策略能够在推理过程中对与过去动作一致的候选者进行评分和选择。在四个真实世界和六个模拟任务中的实验表明,我们提出的方法将长上下文扩散策略的性能提高了3倍,并将策略训练加速了10倍以上。
🔬 方法详解
问题定义:论文旨在解决长序列机器人任务中,Diffusion策略难以有效利用长上下文信息的问题。现有方法通常通过截断上下文长度来降低计算复杂度,但会丢失重要的历史信息,导致策略性能下降,无法捕捉过去和未来动作之间的依赖关系。
核心思路:论文的核心思路是通过引入Past-Token Prediction (PTP) 辅助任务,显式地正则化策略对过去信息的保留。PTP要求策略不仅预测未来的动作token,还要预测过去的动作token,从而迫使策略学习捕捉过去和未来动作之间的依赖关系,提升时间建模能力。
技术框架:整体框架包含三个主要阶段:1) 使用短上下文预训练视觉编码器;2) 使用缓存的长上下文嵌入和PTP辅助任务微调策略头;3) 在测试时,利用PTP进行自我验证,对候选动作进行评分和选择。策略网络采用Diffusion模型,输入包括视觉信息和动作序列,输出为预测的动作分布。
关键创新:最重要的创新点在于PTP辅助任务的设计。与传统的模仿学习方法不同,PTP不是简单地复制过去的动作,而是通过预测过去动作token来学习过去和未来动作之间的潜在依赖关系。此外,多阶段训练策略和测试时的自我验证机制也进一步提升了性能。
关键设计:PTP的损失函数通常是交叉熵损失,用于衡量预测的过去动作token与真实token之间的差异。多阶段训练策略的关键在于平衡视觉编码器和策略头的训练,避免过拟合。测试时的自我验证机制通过计算候选动作与过去动作的一致性得分来选择最佳动作。
🖼️ 关键图片
📊 实验亮点
实验结果表明,提出的PTP方法在多个真实世界和模拟机器人任务中,将长上下文Diffusion策略的性能提高了3倍,并将策略训练加速了10倍以上。相较于基线方法,PTP能够更有效地利用长上下文信息,从而实现更优的策略性能。
🎯 应用场景
该研究成果可应用于各种需要长时序依赖关系的机器人任务,例如复杂操作、长期导航、人机协作等。通过提升策略对历史信息的利用能力,可以显著提高机器人的决策质量和任务完成效率,使其在更复杂的环境中自主工作。
📄 摘要(原文)
Reasoning over long sequences of observations and actions is essential for many robotic tasks. Yet, learning effective long-context policies from demonstrations remains challenging. As context length increases, training becomes increasingly expensive due to rising memory demands, and policy performance often degrades as a result of spurious correlations. Recent methods typically sidestep these issues by truncating context length, discarding historical information that may be critical for subsequent decisions. In this paper, we propose an alternative approach that explicitly regularizes the retention of past information. We first revisit the copycat problem in imitation learning and identify an opposite challenge in recent diffusion policies: rather than over-relying on prior actions, they often fail to capture essential dependencies between past and future actions. To address this, we introduce Past-Token Prediction (PTP), an auxiliary task in which the policy learns to predict past action tokens alongside future ones. This regularization significantly improves temporal modeling in the policy head, with minimal reliance on visual representations. Building on this observation, we further introduce a multistage training strategy: pre-train the visual encoder with short contexts, and fine-tune the policy head using cached long-context embeddings. This strategy preserves the benefits of PTP while greatly reducing memory and computational overhead. Finally, we extend PTP into a self-verification mechanism at test time, enabling the policy to score and select candidates consistent with past actions during inference. Experiments across four real-world and six simulated tasks demonstrate that our proposed method improves the performance of long-context diffusion policies by 3x and accelerates policy training by more than 10x.