JEDI: Joint Embedding Diffusion World Model for Online Model-Based Reinforcement Learning
作者: Jing Yu Lim, Rushi Shah, Zarif Ikram, Samson Yu, Haozhe Ma, Tze-Yun Leong, Dianbo Liu
分类: cs.LG
发布日期: 2026-05-13
💡 一句话要点
提出JEDI:一种用于在线模型强化学习的联合嵌入扩散世界模型
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 扩散模型 世界模型 强化学习 在线学习 潜在空间 预测表征学习 JEPA 端到端学习
📋 核心要点
- 现有基于扩散的世界模型在像素空间计算成本高,在潜在空间性能不足,且依赖预训练的潜在空间。
- JEDI通过联合嵌入扩散的方式,直接从扩散去噪损失中学习潜在空间,实现端到端的训练。
- 实验表明,JEDI在Atari100k上具有竞争力,且相比像素扩散基线,显著降低了计算资源消耗并提升了训练速度。
📝 摘要(中文)
扩散世界模型最近在在线模型强化学习中展现出竞争力,但现有方法存在一个矛盾:像素扩散虽然有效,但计算成本高昂;最新的潜在扩散方法提高了效率,但性能欠佳。后者还依赖于单独训练的潜在空间,而不是驱动现代MBRL进展的端到端世界模型目标。特别是,JEPA风格的预测表征学习已成为世界建模和MBRL的一个有希望的方向。同时,扩散风格的目标在多个领域受到关注,迭代细化是多模态和随机目标的一个有希望的方法。综上所述,这些趋势促使我们提出了联合嵌入扩散(JEDI),这是第一个在线端到端潜在扩散世界模型。JEDI直接从具有JEPA框架的扩散去噪损失中学习其潜在空间,使用去噪来学习和预测未来的潜在空间,而不是依赖于重建和预训练模型。我们提供了一个理论动机,表明传统的JEPA目标会诱导预测信息瓶颈,并且条件扩散去噪允许密切相关的预测压缩分解。在Atari100k上,JEDI具有竞争力,并且优于具有单独训练的潜在空间的基线。相对于像素扩散基线,JEDI使用的VRAM减少了43%,世界模型采样速度提高了3倍以上,训练速度提高了2.5倍。JEDI还表现出与像素基线明显不同的任务级性能曲线,这表明端到端预测潜在空间的变化不仅仅是计算上的。
🔬 方法详解
问题定义:现有基于扩散的世界模型在在线模型强化学习中面临计算效率和性能之间的权衡。像素扩散模型计算成本高昂,而依赖预训练潜在空间的潜在扩散模型性能不佳,并且无法实现端到端的优化。因此,如何设计一个既高效又高性能的端到端扩散世界模型是一个关键问题。
核心思路:JEDI的核心思路是利用联合嵌入扩散的方式,直接从扩散去噪损失中学习潜在空间,并结合JEPA(Joint Embedding Predictive Architecture)框架进行预测表征学习。通过这种方式,JEDI避免了对预训练潜在空间的依赖,实现了端到端的优化,从而在计算效率和性能之间取得更好的平衡。
技术框架:JEDI的整体框架包括以下几个主要模块:1) 观测编码器:将原始观测数据编码到潜在空间;2) 扩散模型:在潜在空间中进行扩散和去噪操作,学习潜在空间的分布;3) JEPA预测器:基于当前潜在状态预测未来的潜在状态;4) 强化学习策略:基于预测的未来状态进行决策。整个流程是端到端可训练的,通过联合优化扩散模型和JEPA预测器,实现对环境的有效建模。
关键创新:JEDI最重要的技术创新点在于其端到端的潜在空间学习方式。与以往依赖预训练潜在空间的方法不同,JEDI直接从扩散去噪损失中学习潜在空间,并结合JEPA框架进行预测表征学习。这种方式使得JEDI能够更好地适应特定任务,并实现更高的性能。此外,论文还从理论上证明了传统JEPA目标会诱导预测信息瓶颈,而条件扩散去噪允许密切相关的预测压缩分解。
关键设计:JEDI的关键设计包括:1) 使用扩散模型进行潜在空间的建模,通过迭代去噪的方式学习潜在空间的分布;2) 采用JEPA框架进行预测表征学习,通过预测未来的潜在状态来学习环境的动态特性;3) 设计合适的损失函数,包括扩散去噪损失和预测损失,以实现对扩散模型和JEPA预测器的联合优化;4) 针对Atari100k等具体任务,选择合适的网络结构和超参数,以获得最佳性能。
🖼️ 关键图片
📊 实验亮点
JEDI在Atari100k上取得了显著的实验结果。与像素扩散基线相比,JEDI使用的VRAM减少了43%,世界模型采样速度提高了3倍以上,训练速度提高了2.5倍。此外,JEDI还表现出与像素基线明显不同的任务级性能曲线,表明端到端预测潜在空间能够带来性能提升。JEDI也优于使用单独训练的潜在空间的基线。
🎯 应用场景
JEDI具有广泛的应用前景,可应用于机器人控制、游戏AI、自动驾驶等领域。通过学习环境的动态模型,JEDI可以帮助智能体更好地理解环境,并做出更明智的决策。此外,JEDI的端到端训练方式和高效的计算性能,使其能够应用于在线强化学习场景,从而实现智能体的实时学习和适应。
📄 摘要(原文)
Diffusion world models have recently become competitive for online model-based reinforcement learning, but current approaches expose a tension: pixel diffusion is effective but computationally expensive while the latest latent diffusion approach improves efficiency yet performs subpar. The latter also relies on separately trained latents rather than the end-to-end world-model objectives that have driven much of modern MBRL progress. In particular, JEPA-style predictive representation learning has emerged as an especially promising direction for world modeling and MBRL. Concurrently, diffusion-style objectives have gained traction across multiple domains, with iterative refinement as a promising approach for multimodal and stochastic targets. Taken together, these trends motivate Joint Embedding DIffusion (JEDI), the first online end-to-end latent diffusion world model. JEDI learns its latent space directly from the diffusion denoising loss with a JEPA framework, using denoising to learn and predict future latents rather than relying on reconstruction and pretrained models. We provide a theoretical motivation showing that conventional JEPA objectives induce a predictive information bottleneck, and that conditional diffusion denoising admits a closely related predictive-compression decomposition. Empirically, JEDI is competitive on Atari100k and outperforms the baseline with seperately trained latents where directly comparable. Relative to the pixel diffusion baseline, JEDI uses 43% less VRAM, over 3$\times$ faster world-model sampling, and 2.5$\times$ faster training. JEDI also exhibits a markedly different task-level performance profile from the pixel baseline, suggesting that end-to-end predictive latents change more than compute alone.