Why and How Auxiliary Tasks Improve JEPA Representations
作者: Jiacan Yu, Siyi Chen, Mingrui Liu, Nono Horiuchi, Vladimir Braverman, Zicheng Xu, Dan Haramati, Randall Balestriero
分类: cs.LG, cs.AI
发布日期: 2025-09-12 (更新: 2025-10-19)
💡 一句话要点
通过辅助任务改进JEPA表征,解决表征坍塌问题并提升表征质量
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 联合嵌入预测架构 JEPA 辅助任务 表征学习 表征坍塌
📋 核心要点
- JEPA在视觉表征学习和强化学习中应用广泛,但其内在机制尚不明确,存在表征坍塌的风险。
- 论文提出联合训练潜在动态和辅助回归头的JEPA变体,通过辅助任务锚定表征,避免不良表征坍塌。
- 在计数环境中的实验表明,联合训练的JEPA模型能生成更丰富的表征,验证了理论的有效性。
📝 摘要(中文)
联合嵌入预测架构(JEPA)越来越多地被用于视觉表征学习和基于模型的强化学习中,但其行为仍然缺乏深入理解。本文对一个简单的、实用的JEPA变体进行了理论分析,该变体具有一个与潜在动态联合训练的辅助回归头。我们证明了一个“无不良表征坍塌”定理:在确定性MDP中,如果训练使潜在转移一致性损失和辅助回归损失都趋于零,那么任何一对非等价的观测,即那些不具有相同转移动态或辅助值的观测,必须映射到不同的潜在表征。因此,辅助任务锚定了表征必须保留的区别。在计数环境中的受控消融实验证实了该理论,并表明与单独训练相比,将JEPA模型与辅助头联合训练可以生成更丰富的表征。我们的工作为改进JEPA编码器提供了一条途径:使用一个辅助函数对其进行训练,该函数与转移动态一起,编码正确的等价关系。
🔬 方法详解
问题定义:现有的JEPA模型在训练过程中,容易出现表征坍塌的问题,即不同的观测可能被映射到相同的潜在表征,导致模型无法区分不同的状态或动作。这限制了JEPA在视觉表征学习和强化学习中的应用效果。现有方法缺乏对JEPA表征能力的理论分析,难以指导模型设计和训练。
核心思路:论文的核心思路是通过引入辅助任务来约束JEPA的表征空间,防止表征坍塌。具体而言,论文提出联合训练潜在动态和辅助回归头的JEPA变体。辅助任务的目标是预测与观测相关的某些属性或特征。通过同时优化潜在转移一致性损失和辅助回归损失,可以确保JEPA学习到的潜在表征能够区分具有不同转移动态或辅助值的观测。
技术框架:该方法基于标准的JEPA框架,主要包含编码器、潜在转移模型和辅助回归头三个模块。编码器将观测映射到潜在表征空间。潜在转移模型预测下一个潜在状态。辅助回归头基于潜在表征预测辅助任务的目标值。整个框架通过最小化潜在转移一致性损失和辅助回归损失进行端到端训练。
关键创新:论文的关键创新在于提出了“无不良表征坍塌”定理,并从理论上证明了辅助任务在防止表征坍塌中的作用。该定理表明,只要辅助任务能够编码正确的等价关系,联合训练JEPA模型和辅助头就可以保证学习到具有区分性的潜在表征。
关键设计:论文的关键设计包括:1) 选择合适的辅助任务,使其能够编码与任务相关的关键信息;2) 设计合适的辅助回归头,使其能够准确预测辅助任务的目标值;3) 平衡潜在转移一致性损失和辅助回归损失的权重,以确保模型能够同时学习到潜在动态和辅助任务的信息。
📊 实验亮点
论文在计数环境中进行了受控消融实验,结果表明,与单独训练相比,将JEPA模型与辅助头联合训练可以生成更丰富的表征。具体而言,联合训练的模型能够更准确地预测环境中的计数信息,并且在下游任务中表现更好。这些实验结果验证了理论的有效性,并表明辅助任务可以有效地改善JEPA模型的表征能力。
🎯 应用场景
该研究成果可应用于视觉表征学习、强化学习等领域。通过引入合适的辅助任务,可以提升JEPA模型的表征能力,从而改善下游任务的性能。例如,在机器人控制中,可以利用辅助任务预测机器人的状态或环境的属性,从而提高机器人的感知和决策能力。该研究也为设计更有效的自监督学习方法提供了新的思路。
📄 摘要(原文)
Joint-Embedding Predictive Architecture (JEPA) is increasingly used for visual representation learning and as a component in model-based RL, but its behavior remains poorly understood. We provide a theoretical characterization of a simple, practical JEPA variant that has an auxiliary regression head trained jointly with latent dynamics. We prove a No Unhealthy Representation Collapse theorem: in deterministic MDPs, if training drives both the latent-transition consistency loss and the auxiliary regression loss to zero, then any pair of non-equivalent observations, i.e., those that do not have the same transition dynamics or auxiliary value, must map to distinct latent representations. Thus, the auxiliary task anchors which distinctions the representation must preserve. Controlled ablations in a counting environment corroborate the theory and show that training the JEPA model jointly with the auxiliary head generates a richer representation than training them separately. Our work indicates a path to improve JEPA encoders: training them with an auxiliary function that, together with the transition dynamics, encodes the right equivalence relations.