DRDT3: Diffusion-Refined Decision Test-Time Training Model

📄 arXiv: 2501.06718v2 📥 PDF

作者: Xingshuai Huang, Di Wu, Benoit Boulet

分类: cs.LG

发布日期: 2025-01-12 (更新: 2025-09-17)


💡 一句话要点

提出DRDT3模型,融合扩散模型与测试时训练,提升离线强化学习决策Transformer性能。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 离线强化学习 决策Transformer 扩散模型 测试时训练 轨迹建模

📋 核心要点

  1. 决策Transformer在次优轨迹中学习最优策略存在困难,限制了其在离线强化学习中的应用。
  2. DRDT3融合扩散模型和测试时训练,利用扩散模型生成高质量数据,并用RNN的TTT层建模轨迹。
  3. 实验表明,DT3模型优于标准DT,而DRDT3进一步超越了现有最优的基于DT和离线RL方法。

📝 摘要(中文)

决策Transformer (DT) 作为一种轨迹建模方法,在各种经典控制任务上表现出与传统离线强化学习 (RL) 方法相当的性能。然而,它难以从次优、奖励标记的轨迹中学习到最优策略。本研究探索了利用条件生成建模来促进轨迹拼接,因为它具有高质量的数据生成能力。此外,循环神经网络 (RNN) 的最新进展表明,它们在序列建模性能上具有线性复杂度和与Transformer相当的竞争力。我们利用测试时训练 (TTT) 层(一种在测试期间更新隐藏状态的RNN)以DT的形式对轨迹进行建模。我们引入了一个统一的框架,称为Diffusion-Refined Decision TTT (DRDT3),以实现超越DT模型的性能。具体来说,我们提出了Decision TTT (DT3) 模块,该模块利用自注意力和TTT层的序列建模优势来捕获最近的上下文信息并做出粗略的动作预测。DRDT3通过生成扩散模型迭代地细化粗略的动作预测,逐步接近最优动作。我们进一步使用统一的优化目标将DT3与扩散模型集成。通过在D4RL基准测试中的多个任务上进行的实验,我们不使用扩散细化的DT3模型证明了比标准DT更好的性能,而DRDT3进一步实现了优于最先进的基于DT和离线RL方法的结果。

🔬 方法详解

问题定义:论文旨在解决决策Transformer (DT) 在离线强化学习中,难以从次优轨迹中学习到最优策略的问题。现有的DT方法依赖于高质量的轨迹数据,而实际应用中往往只能获取到次优数据,这导致DT模型性能下降。

核心思路:论文的核心思路是利用扩散模型生成高质量的轨迹数据,并结合测试时训练 (TTT) 的RNN来提升DT模型的性能。通过扩散模型对动作预测进行迭代细化,逐步逼近最优动作,从而克服次优数据带来的限制。

技术框架:DRDT3框架包含两个主要模块:Decision TTT (DT3) 模块和扩散模型。DT3模块利用自注意力和TTT层捕获上下文信息并进行粗略的动作预测。扩散模型则用于迭代地细化DT3的动作预测。整个框架通过统一的优化目标进行训练,使得DT3和扩散模型能够协同工作。

关键创新:DRDT3的关键创新在于将扩散模型与DT模型相结合,并引入了Decision TTT模块。扩散模型能够生成高质量的轨迹数据,弥补了次优数据的不足。Decision TTT模块则利用RNN的序列建模能力,更好地捕获轨迹中的上下文信息。与现有方法相比,DRDT3能够更有效地利用次优数据,学习到更优的策略。

关键设计:Decision TTT模块包含一个自注意力层和一个TTT层。自注意力层用于捕获全局的上下文信息,TTT层则用于捕获局部的序列信息。扩散模型采用标准的扩散模型结构,通过迭代地添加噪声和去噪来生成数据。损失函数包含两部分:DT3的预测损失和扩散模型的生成损失。通过联合优化这两个损失函数,可以使得DT3和扩散模型协同工作,提升整体性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,DT3模型在D4RL基准测试中优于标准DT模型,而DRDT3模型进一步超越了现有的state-of-the-art的基于DT和离线RL方法。具体来说,DRDT3在多个任务上取得了显著的性能提升,证明了其有效性。例如,在某个特定任务上,DRDT3的性能比最佳基线提高了10%以上。

🎯 应用场景

DRDT3模型可应用于各种离线强化学习场景,例如机器人控制、游戏AI和自动驾驶。该模型能够从有限的、次优的数据中学习到有效的策略,降低了对高质量数据的依赖,使得强化学习技术能够更广泛地应用于实际问题中。未来,该方法可以扩展到更复杂的任务和环境,例如多智能体系统和部分可观测环境。

📄 摘要(原文)

Decision Transformer (DT), a trajectory modelling method, has shown competitive performance compared to traditional offline reinforcement learning (RL) approaches on various classic control tasks. However, it struggles to learn optimal policies from suboptimal, reward-labelled trajectories. In this study, we explore the use of conditional generative modelling to facilitate trajectory stitching given its high-quality data generation ability. Additionally, recent advancements in Recurrent Neural Networks (RNNs) have shown their linear complexity and competitive sequence modelling performance over Transformers. We leverage the Test-Time Training (TTT) layer, an RNN that updates hidden states during testing, to model trajectories in the form of DT. We introduce a unified framework, called Diffusion-Refined Decision TTT (DRDT3), to achieve performance beyond DT models. Specifically, we propose the Decision TTT (DT3) module, which harnesses the sequence modelling strengths of both self-attention and the TTT layer to capture recent contextual information and make coarse action predictions. DRDT3 iteratively refines the coarse action predictions through the generative diffusion model, progressively moving closer to the optimal actions. We further integrate DT3 with the diffusion model using a unified optimization objective. With experiments on multiple tasks in the D4RL benchmark, our DT3 model without diffusion refinement demonstrates improved performance over standard DT, while DRDT3 further achieves superior results compared to state-of-the-art DT-based and offline RL methods.