LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels

📄 arXiv: 2603.19312v2 📥 PDF

作者: Lucas Maes, Quentin Le Lidec, Damien Scieur, Yann LeCun, Randall Balestriero

分类: cs.LG, cs.AI

发布日期: 2026-03-13 (更新: 2026-03-24)


💡 一句话要点

LeWorldModel:提出一种稳定的端到端像素级联合嵌入预测架构,用于学习世界模型。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 世界模型 联合嵌入 预测架构 端到端学习 机器人控制

📋 核心要点

  1. 现有JEPAs方法在学习世界模型时依赖复杂损失函数和预训练等,存在训练脆弱性和依赖性问题。
  2. LeWorldModel通过简化损失函数,仅使用预测损失和高斯正则化,实现了端到端稳定训练。
  3. LeWM在控制任务中表现出竞争力的同时,规划速度提升显著,且能有效编码物理结构并检测异常事件。

📝 摘要(中文)

联合嵌入预测架构(JEPAs)为在紧凑的潜在空间中学习世界模型提供了一个引人注目的框架,但现有方法仍然脆弱,依赖于复杂的多项损失、指数移动平均、预训练编码器或辅助监督来避免表征崩溃。本文介绍LeWorldModel (LeWM),这是第一个从原始像素端到端稳定训练的JEPA,仅使用两个损失项:下一个嵌入预测损失和一个强制高斯分布潜在嵌入的正则化项。与现有的唯一端到端替代方案相比,这减少了可调损失超参数,从六个减少到一个。LeWM具有约1500万个参数,可在单个GPU上在几个小时内训练,其规划速度比基于基础模型的世界模型快48倍,同时在各种2D和3D控制任务中保持竞争力。除了控制之外,我们还表明LeWM的潜在空间通过物理量的探测编码了有意义的物理结构。惊奇评估证实,该模型可靠地检测到物理上不合理的事件。

🔬 方法详解

问题定义:现有基于联合嵌入预测架构(JEPA)的世界模型学习方法,为了避免表征坍塌,通常需要复杂的多项损失函数、指数移动平均、预训练编码器或辅助监督。这些方法增加了训练的复杂性,使得模型对超参数敏感,训练过程不稳定,并且依赖于额外的资源或先验知识。

核心思路:LeWorldModel的核心思路是通过极简的损失函数设计,实现从原始像素到紧凑潜在空间的稳定端到端训练。具体来说,只使用两个损失项:一个是预测下一个嵌入的损失,另一个是强制潜在嵌入服从高斯分布的正则化项。这种设计旨在减少超参数的调整,简化训练过程,并提高模型的泛化能力。

技术框架:LeWorldModel的整体架构包含一个编码器和一个预测器。编码器将原始像素输入映射到潜在空间中的嵌入表示。预测器接收当前嵌入作为输入,并预测下一个时间步的嵌入。模型通过最小化预测损失和高斯正则化损失进行训练。在推理阶段,可以使用预测器进行未来状态的预测和规划。

关键创新:LeWorldModel最重要的创新在于其极简的损失函数设计,它只包含两个损失项,分别是预测损失和高斯正则化损失。这种设计显著简化了训练过程,提高了训练的稳定性,并减少了对超参数的依赖。与现有方法相比,LeWM是第一个能够从原始像素端到端稳定训练的JEPA模型。

关键设计:LeWM的关键设计包括:1) 使用Transformer架构作为编码器和预测器,以捕获时间序列数据中的长期依赖关系;2) 使用L2损失作为预测损失,衡量预测嵌入和真实嵌入之间的差异;3) 使用KL散度作为高斯正则化损失,强制潜在嵌入服从标准高斯分布;4) 通过控制高斯正则化损失的权重,平衡预测精度和潜在空间的结构化程度。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

LeWorldModel在多个2D和3D控制任务中表现出与现有方法相当的性能,同时规划速度提升高达48倍。此外,LeWM的潜在空间能够编码有意义的物理结构,并且能够可靠地检测物理上不合理的事件。这些结果表明LeWM是一种高效、稳定且具有良好泛化能力的通用世界模型学习方法。

🎯 应用场景

LeWorldModel具有广泛的应用前景,包括机器人控制、自动驾驶、游戏AI等领域。它可以用于学习复杂环境的动态模型,从而使智能体能够更好地进行规划和决策。此外,LeWM还可以用于异常检测,例如在视频监控中检测不寻常的事件。该研究的成果有助于推动通用人工智能的发展。

📄 摘要(原文)

Joint Embedding Predictive Architectures (JEPAs) offer a compelling framework for learning world models in compact latent spaces, yet existing methods remain fragile, relying on complex multi-term losses, exponential moving averages, pre-trained encoders, or auxiliary supervision to avoid representation collapse. In this work, we introduce LeWorldModel (LeWM), the first JEPA that trains stably end-to-end from raw pixels using only two loss terms: a next-embedding prediction loss and a regularizer enforcing Gaussian-distributed latent embeddings. This reduces tunable loss hyperparameters from six to one compared to the only existing end-to-end alternative. With ~15M parameters trainable on a single GPU in a few hours, LeWM plans up to 48x faster than foundation-model-based world models while remaining competitive across diverse 2D and 3D control tasks. Beyond control, we show that LeWM's latent space encodes meaningful physical structure through probing of physical quantities. Surprise evaluation confirms that the model reliably detects physically implausible events.