UWM-JEPA: Predictive World Models That Imagine in Belief Space
作者: Santosh Kumar Radha, Oktay Goktas
分类: cs.LG, cs.AI, cs.RO, stat.ML
发布日期: 2026-05-25
备注: 14 pages, 6 figures, 7 tables. Code and data: https://github.com/santoshkumarradha/uwm-jepa
💡 一句话要点
提出UWM-JEPA,利用密度矩阵和酉预测器提升部分可观测环境下的世界模型预测能力
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 世界模型 部分可观测 密度矩阵 酉预测器 联合嵌入 反事实推理 强化学习
📋 核心要点
- 现有世界模型在部分可观测环境下难以准确预测未来,尤其是在反事实动作下的表现不佳。
- UWM-JEPA使用密度矩阵作为潜在变量,并学习酉预测器,以保持rollout期间的不确定性。
- 实验表明,UWM-JEPA在隐藏速度指示器任务和盲rollout中优于LSTM-JEPA等基线模型。
📝 摘要(中文)
针对部分可观测环境下的世界模型,需要预测多个兼容的隐藏未来并在反事实行为下进行引导的问题,本文提出了Unitary World Model JEPA (UWM-JEPA)。UWM-JEPA是一种JEPA世界模型,它在联合系统-环境空间上使用密度矩阵作为潜在变量,并学习一个酉预测器。这种结构在rollout期间精确地保留了联合状态谱,因此预测器本身无法消散所表示的不确定性。在隐藏速度指示器任务中,UWM-JEPA达到了0.77的准确率,并且随着动作扰动单调下降。相比之下,参数匹配的LSTM-JEPA在每个动作条件下都崩溃到多数类准确率(0.53)。在盲rollout下,UWM-JEPA在短horizon内损失的探针R^2少于10个点,而向量潜在基线损失了41和68个点。动作敏感性本身需要针对反事实而非教师强制目标进行训练,这一发现适用于酉参数化之外。对于JEPA世界模型在部分可观测下进行想象,潜在几何和预测器动态很重要,而不仅仅是冻结的上下文编码能力。
🔬 方法详解
问题定义:论文旨在解决部分可观测环境下,世界模型难以准确预测未来状态的问题。现有方法,如基于向量潜在变量的JEPA,在盲rollout中会丢失大量信息,无法有效处理不确定性,并且对反事实动作的敏感性不足。
核心思路:论文的核心思路是使用密度矩阵来表示潜在状态,并学习一个酉预测器。密度矩阵能够更好地表示不确定性,而酉预测器能够保证在rollout过程中信息不会丢失,从而提高预测的准确性和鲁棒性。
技术框架:UWM-JEPA的整体框架包括一个编码器、一个潜在空间(密度矩阵)和一个酉预测器。编码器将观测输入映射到潜在空间中的密度矩阵。酉预测器根据当前状态和动作预测下一个状态的密度矩阵。通过rollout,可以预测未来多个时间步的状态。
关键创新:最重要的技术创新点是使用密度矩阵作为潜在变量和学习酉预测器。与传统的向量潜在变量相比,密度矩阵能够更好地表示不确定性。酉预测器保证了在rollout过程中信息不会丢失,避免了不确定性的消散。
关键设计:论文中使用了特定的损失函数来训练模型,包括重构损失和对比损失。网络结构方面,编码器可以使用卷积神经网络或Transformer等。酉预测器可以使用参数化的酉矩阵或学习到的酉变换。具体的参数设置和网络结构需要根据具体的任务进行调整。
🖼️ 关键图片
📊 实验亮点
UWM-JEPA在隐藏速度指示器任务中达到了0.77的准确率,显著优于LSTM-JEPA的0.53。在盲rollout中,UWM-JEPA在短horizon内损失的探针R^2少于10个点,而向量潜在基线损失了41和68个点。这些结果表明,UWM-JEPA在处理不确定性和进行长期预测方面具有显著优势。
🎯 应用场景
该研究成果可应用于机器人导航、游戏AI、自动驾驶等领域。在这些领域中,智能体需要在部分可观测的环境下进行决策,并预测未来状态。UWM-JEPA能够提高智能体对环境的理解和预测能力,从而做出更明智的决策,提升任务完成的效率和安全性。
📄 摘要(原文)
World models for partially observed environments must imagine multiple compatible hidden futures and steer between them under counterfactual actions. Joint Embedding Predictive Architectures (JEPAs) do this in latent space, but a vector-valued latent has no internal structure for carrying the belief over hidden continuations through blind rollout. We introduce the Unitary World Model JEPA (UWM-JEPA), a JEPA world model with a density-matrix latent on a joint system-environment space and a learned unitary predictor. The construction preserves the joint-state spectrum exactly during rollout, so the predictor itself cannot dissipate the represented uncertainty. On a hidden-velocity indicator task requiring five-step forward simulation under a given action sequence with the target observation masked, UWM-JEPA reaches 0.77 accuracy and degrades monotonically as actions are perturbed; a parameter-matched LSTM-JEPA trained under the same counterfactual-target objective and action head collapses to majority-class accuracy (0.53) under every action condition. Under blind rollout, UWM-JEPA loses fewer than ten points of probe R^2 at short horizons while vector-latent baselines lose forty-one and sixty-eight; both nevertheless tie on a held-out context probe, locating the separation in the predictor rather than the encoder. Action sensitivity itself requires training against counterfactual rather than teacher-forced targets, a finding that applies beyond the unitary parameterisation. For JEPA world models to imagine under partial observability, latent geometry and predictor dynamics matter, not frozen context-encoding capacity alone.