Learning Transformer-based World Models with Contrastive Predictive Coding

📄 arXiv: 2503.04416v2 📥 PDF

作者: Maxime Burchi, Radu Timofte

分类: cs.LG, cs.AI, cs.CV

发布日期: 2025-03-06 (更新: 2025-05-25)


💡 一句话要点

提出TWISTER:基于对比预测编码学习Transformer世界模型,提升强化学习性能

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 世界模型 Transformer 对比预测编码 强化学习 Atari 100k

📋 核心要点

  1. 现有基于Transformer的世界模型在强化学习中表现不如基于RNN的模型,原因在于简单地预测下一个状态无法充分利用Transformer的潜力。
  2. TWISTER通过引入动作条件对比预测编码,学习高层次的时间特征表示,从而扩展世界模型的预测范围,提升智能体的性能。
  3. TWISTER在Atari 100k基准测试中取得了162%的人类标准化平均得分,超越了其他不使用前瞻搜索的先进方法。

📝 摘要(中文)

DreamerV3算法通过学习基于循环神经网络(RNN)的精确世界模型,在各种环境领域取得了显著的性能。受模型强化学习算法的成功以及Transformer架构在训练效率和可扩展性方面的快速普及的推动,诸如STORM等最新工作提出了使用基于掩码自注意力机制的Transformer世界模型来替代基于RNN的世界模型。然而,尽管这些方法提高了训练效率,但与Dreamer算法相比,它们对性能的影响仍然有限,难以学习具有竞争力的Transformer世界模型。本文表明,先前方法中采用的下一个状态预测目标不足以充分利用Transformer的表示能力。我们提出通过引入TWISTER(Transformer-based World model wIth contraSTivE Representations),一种使用动作条件对比预测编码来学习高级时间特征表示并提高智能体性能的世界模型,将世界模型预测扩展到更长的时间范围。TWISTER在Atari 100k基准测试中实现了162%的人类标准化平均得分,在不采用前瞻搜索的最新方法中创下了新纪录。

🔬 方法详解

问题定义:现有基于Transformer的世界模型在强化学习任务中,性能与基于RNN的DreamerV3相比仍有差距。核心问题在于,简单地使用Transformer预测下一个状态,无法充分利用Transformer强大的表征能力,尤其是在处理长期依赖关系时。现有方法未能有效提取和利用环境中的时间信息,导致学习到的世界模型不够准确和鲁棒。

核心思路:TWISTER的核心思路是利用对比预测编码(Contrastive Predictive Coding, CPC)来学习环境的高级时间特征表示。通过最大化未来状态和当前状态表示之间的互信息,TWISTER能够学习到更具判别性的特征,从而更好地捕捉环境的动态变化。动作条件对比预测编码则进一步考虑了动作对状态转移的影响,使得学习到的表示更加符合实际情况。

技术框架:TWISTER的整体框架包括以下几个主要模块:1) 编码器:将原始观测数据编码为低维状态表示。2) 动作嵌入:将动作信息嵌入到状态表示中。3) Transformer模型:利用Transformer模型学习状态表示的时间依赖关系,预测未来状态的表示。4) 对比预测模块:使用对比损失函数,最大化预测的未来状态表示和实际未来状态表示之间的互信息。整个流程是,智能体与环境交互,收集经验数据,然后利用这些数据训练TWISTER世界模型。

关键创新:TWISTER的关键创新在于将对比预测编码与Transformer世界模型相结合。与传统的下一个状态预测方法相比,对比预测编码能够学习到更具鲁棒性和泛化能力的特征表示,从而提高世界模型的预测精度。此外,动作条件对比预测编码的设计,使得模型能够更好地理解动作对环境的影响,从而做出更明智的决策。

关键设计:TWISTER的关键设计包括:1) 对比损失函数:使用InfoNCE损失函数来最大化未来状态和当前状态表示之间的互信息。2) Transformer结构:采用标准的Transformer结构,并根据具体任务调整Transformer的层数和注意力头数。3) 动作嵌入方式:将动作信息通过线性层嵌入到状态表示中,并与状态表示相加。4) 训练策略:采用离线训练的方式,利用收集到的经验数据训练TWISTER世界模型。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

TWISTER在Atari 100k基准测试中取得了显著的成果,实现了162%的人类标准化平均得分,超越了其他不使用前瞻搜索的先进方法,例如DreamerV3。这一结果表明,TWISTER能够有效地学习环境的动态变化,并利用这些知识做出更明智的决策。实验结果验证了对比预测编码在Transformer世界模型中的有效性,为未来的研究提供了新的方向。

🎯 应用场景

TWISTER具有广泛的应用前景,包括机器人控制、游戏AI、自动驾驶等领域。通过学习精确的世界模型,智能体可以更好地理解环境的动态变化,从而做出更明智的决策。此外,TWISTER还可以用于生成合成数据,用于训练其他机器学习模型,提高模型的泛化能力。未来,TWISTER有望成为构建通用人工智能的重要组成部分。

📄 摘要(原文)

The DreamerV3 algorithm recently obtained remarkable performance across diverse environment domains by learning an accurate world model based on Recurrent Neural Networks (RNNs). Following the success of model-based reinforcement learning algorithms and the rapid adoption of the Transformer architecture for its superior training efficiency and favorable scaling properties, recent works such as STORM have proposed replacing RNN-based world models with Transformer-based world models using masked self-attention. However, despite the improved training efficiency of these methods, their impact on performance remains limited compared to the Dreamer algorithm, struggling to learn competitive Transformer-based world models. In this work, we show that the next state prediction objective adopted in previous approaches is insufficient to fully exploit the representation capabilities of Transformers. We propose to extend world model predictions to longer time horizons by introducing TWISTER (Transformer-based World model wIth contraSTivE Representations), a world model using action-conditioned Contrastive Predictive Coding to learn high-level temporal feature representations and improve the agent performance. TWISTER achieves a human-normalized mean score of 162% on the Atari 100k benchmark, setting a new record among state-of-the-art methods that do not employ look-ahead search.