Efficient World Models with Context-Aware Tokenization
作者: Vincent Micheli, Eloi Alonso, François Fleuret
分类: cs.LG, cs.AI, cs.CV
发布日期: 2024-06-27
备注: ICML 2024
🔗 代码/项目: GITHUB
💡 一句话要点
提出Δ-IRIS,通过上下文感知 Tokenization 实现高效世界模型,刷新 Crafter 基准。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 世界模型 强化学习 Transformer 离散自编码器 增量学习
📋 核心要点
- 现有基于Transformer的世界模型计算量大,难以扩展到复杂环境,主要瓶颈在于需要处理长序列的tokens。
- Δ-IRIS通过离散自编码器编码时间步之间的增量信息,并使用自回归Transformer预测未来增量,从而降低了序列长度。
- 实验表明,Δ-IRIS在Crafter基准测试中取得了新的state-of-the-art,并且训练速度比之前的模型快一个数量级。
📝 摘要(中文)
深度强化学习(RL)方法的可扩展性是一个重大挑战。借鉴生成建模的进展,基于模型的RL正成为一个强有力的竞争者。序列建模的最新进展已经产生了有效的基于Transformer的世界模型,但由于需要长序列的tokens来准确模拟环境,计算量巨大。本文提出了Δ-IRIS,一种新的智能体,其世界模型架构由一个离散自编码器组成,该自编码器编码时间步之间的随机增量,以及一个自回归Transformer,通过用连续tokens总结世界的当前状态来预测未来的增量。在Crafter基准测试中,Δ-IRIS在多个帧预算下创造了新的state-of-the-art,同时训练速度比以前基于注意力的模型快一个数量级。我们已在https://github.com/vmicheli/delta-iris上发布了我们的代码和模型。
🔬 方法详解
问题定义:现有基于Transformer的世界模型在深度强化学习中面临计算瓶颈,尤其是在需要模拟复杂环境时。由于需要处理长序列的tokens以准确表示环境状态,导致训练和推理成本过高。这限制了它们在实际应用中的可扩展性。
核心思路:Δ-IRIS的核心思路是通过学习时间步之间的增量(delta)信息来降低序列长度,从而减少计算量。它将环境状态的变化编码为离散的tokens,并使用Transformer来预测这些增量,而不是直接预测整个环境状态。这种方法能够更有效地捕捉环境的动态变化,并减少冗余信息。
技术框架:Δ-IRIS的整体架构包含两个主要模块:一个离散自编码器和一个自回归Transformer。离散自编码器负责将环境状态的增量编码为离散的tokens,从而实现对环境动态的压缩表示。自回归Transformer则利用这些tokens来预测未来的增量,从而模拟环境的演化。整个流程包括:1. 观察环境状态;2. 计算状态增量;3. 使用自编码器编码增量为离散tokens;4. 使用Transformer预测未来增量tokens;5. 解码预测的增量tokens,得到未来状态的预测。
关键创新:Δ-IRIS的关键创新在于使用上下文感知的tokenization方法来表示环境状态的增量。与直接对环境状态进行tokenization相比,对增量进行tokenization能够更有效地捕捉环境的动态变化,并减少冗余信息。此外,使用连续tokens来总结世界的当前状态,为Transformer提供了更丰富的上下文信息,有助于提高预测的准确性。
关键设计:离散自编码器使用VQ-VAE结构,将连续的增量向量映射到离散的codebook中的tokens。Transformer采用标准的自回归结构,并使用因果注意力机制来确保预测的未来增量只依赖于过去的信息。损失函数包括重构损失(用于训练自编码器)和预测损失(用于训练Transformer)。具体的参数设置(如codebook大小、Transformer层数、注意力头数等)需要根据具体环境进行调整。
🖼️ 关键图片
📊 实验亮点
Δ-IRIS在Crafter基准测试中取得了显著的成果,在多个帧预算下创造了新的state-of-the-art。与之前的基于注意力的模型相比,Δ-IRIS的训练速度提高了一个数量级,这表明其具有更高的计算效率。这些结果表明,Δ-IRIS是一种有前途的世界模型方法,能够有效地模拟复杂环境,并为智能体的学习提供有力的支持。
🎯 应用场景
Δ-IRIS具有广泛的应用前景,包括游戏AI、机器人控制、自动驾驶等领域。通过高效地模拟环境,它可以帮助智能体更好地学习策略,并在复杂环境中做出更明智的决策。此外,该方法还可以用于生成逼真的虚拟环境,用于训练和评估AI系统。未来,Δ-IRIS有望成为构建更智能、更强大的AI系统的关键技术。
📄 摘要(原文)
Scaling up deep Reinforcement Learning (RL) methods presents a significant challenge. Following developments in generative modelling, model-based RL positions itself as a strong contender. Recent advances in sequence modelling have led to effective transformer-based world models, albeit at the price of heavy computations due to the long sequences of tokens required to accurately simulate environments. In this work, we propose $Δ$-IRIS, a new agent with a world model architecture composed of a discrete autoencoder that encodes stochastic deltas between time steps and an autoregressive transformer that predicts future deltas by summarizing the current state of the world with continuous tokens. In the Crafter benchmark, $Δ$-IRIS sets a new state of the art at multiple frame budgets, while being an order of magnitude faster to train than previous attention-based approaches. We release our code and models at https://github.com/vmicheli/delta-iris.