Deep Reinforcement Learning via Object-Centric Attention

📄 arXiv: 2504.03024v1 📥 PDF

作者: Jannis Blüml, Cedric Derstroff, Bjarne Gregori, Elisabeth Dillies, Quentin Delfosse, Kristian Kersting

分类: cs.LG, cs.AI

发布日期: 2025-04-03

备注: 26 pages, 11 figures, 7 tables


💡 一句话要点

提出基于掩码的目标中心注意力机制OCCAM,提升深度强化学习泛化能力。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 深度强化学习 目标中心表示 注意力机制 泛化能力 Atari游戏

📋 核心要点

  1. 深度强化学习智能体在复杂环境中泛化能力差,易受虚假相关性干扰。
  2. OCCAM通过掩码机制实现目标中心注意力,选择性关注任务相关实体,过滤无关信息。
  3. 实验表明,OCCAM在Atari游戏中提高了对扰动的鲁棒性,降低了样本复杂度。

📝 摘要(中文)

在原始像素输入上训练的深度强化学习智能体,常常难以泛化到训练环境之外,过度依赖虚假相关性和不相关的背景细节。为了解决这个问题,最近出现了目标中心智能体。然而,它们需要针对任务规范量身定制不同的表示。与深度智能体相反,没有单一的目标中心架构可以应用于任何环境。受到认知科学原理和奥卡姆剃刀的启发,我们引入了通过掩码实现的目标中心注意力机制(OCCAM),它选择性地保留任务相关的实体,同时过滤掉不相关的视觉信息。具体来说,OCCAM利用了目标中心的归纳偏置。在Atari基准测试上的经验评估表明,与传统的基于像素的强化学习相比,OCCAM显著提高了对新扰动的鲁棒性,降低了样本复杂度,同时表现出相似或改进的性能。这些结果表明,结构化抽象可以增强泛化能力,而无需显式的符号表示或特定领域的对象提取流程。

🔬 方法详解

问题定义:深度强化学习智能体直接从像素输入学习时,容易受到环境中的噪声和无关信息的干扰,导致泛化能力差。现有目标中心方法虽然能提升泛化性,但通常需要针对特定任务设计不同的表示方法,缺乏通用性。因此,如何设计一种通用的、能够自动提取任务相关对象信息的强化学习方法是一个关键问题。

核心思路:论文的核心思路是利用目标中心的归纳偏置,通过注意力机制选择性地关注图像中与任务相关的对象,同时抑制不相关的背景信息。这种方法受到认知科学和奥卡姆剃刀原则的启发,旨在通过最简洁的方式提取关键信息,从而提高智能体的泛化能力。

技术框架:OCCAM的整体框架包括以下几个主要模块:1) 视觉编码器:将原始像素输入编码成特征表示。2) 对象掩码生成器:基于特征表示生成对象掩码,用于选择性地关注图像中的不同区域。3) 注意力机制:利用对象掩码对特征表示进行加权,突出任务相关的对象特征。4) 强化学习策略网络:基于加权后的特征表示学习最优策略。整个流程可以看作是一个端到端的学习过程,通过强化学习的目标函数来优化各个模块的参数。

关键创新:OCCAM的关键创新在于其通过掩码实现目标中心注意力的方式。与传统的注意力机制不同,OCCAM不是直接学习每个像素或特征的重要性,而是学习对象级别的掩码,从而更好地利用了目标中心的归纳偏置。此外,OCCAM不需要预先定义对象或使用复杂的对象检测算法,而是通过学习的方式自动发现任务相关的对象。

关键设计:对象掩码生成器通常采用卷积神经网络结构,输入是视觉编码器的输出特征,输出是每个像素属于某个对象的概率。注意力机制可以使用乘性注意力或加性注意力等不同的形式。损失函数包括强化学习的奖励函数和用于约束掩码的正则化项,例如鼓励掩码稀疏性的L1正则化。具体的网络结构和参数设置需要根据具体的任务进行调整。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

在Atari游戏上的实验结果表明,OCCAM在面对新的扰动时,相比于传统的基于像素的强化学习方法,表现出更强的鲁棒性。同时,OCCAM还降低了样本复杂度,即在更少的训练样本下就能达到相似甚至更好的性能。例如,在某些游戏中,OCCAM的性能提升超过10%。

🎯 应用场景

该研究成果可应用于各种需要智能体具备良好泛化能力的场景,例如机器人导航、自动驾驶、游戏AI等。通过关注任务相关的对象,智能体可以更好地适应新的环境和任务,提高决策的准确性和效率。未来,该方法还可以扩展到多智能体系统和更复杂的任务中。

📄 摘要(原文)

Deep reinforcement learning agents, trained on raw pixel inputs, often fail to generalize beyond their training environments, relying on spurious correlations and irrelevant background details. To address this issue, object-centric agents have recently emerged. However, they require different representations tailored to the task specifications. Contrary to deep agents, no single object-centric architecture can be applied to any environment. Inspired by principles of cognitive science and Occam's Razor, we introduce Object-Centric Attention via Masking (OCCAM), which selectively preserves task-relevant entities while filtering out irrelevant visual information. Specifically, OCCAM takes advantage of the object-centric inductive bias. Empirical evaluations on Atari benchmarks demonstrate that OCCAM significantly improves robustness to novel perturbations and reduces sample complexity while showing similar or improved performance compared to conventional pixel-based RL. These results suggest that structured abstraction can enhance generalization without requiring explicit symbolic representations or domain-specific object extraction pipelines.