MaDi: Learning to Mask Distractions for Generalization in Visual Deep Reinforcement Learning
作者: Bram Grooten, Tristan Tomilin, Gautham Vasan, Matthew E. Taylor, A. Rupam Mahmood, Meng Fang, Mykola Pechenizkiy, Decebal Constantin Mocanu
分类: cs.LG, cs.AI, cs.CV, cs.RO
发布日期: 2023-12-23
备注: Accepted as full-paper (oral) at AAMAS 2024. Code is available at https://github.com/bramgrooten/mask-distractions and see our 40-second video at https://youtu.be/2oImF0h1k48
💡 一句话要点
MaDi:学习掩蔽视觉深度强化学习中的干扰,提升泛化能力
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 视觉强化学习 泛化能力 注意力机制 干扰抑制 深度学习
📋 核心要点
- 现有视觉强化学习方法在处理干扰信息时,依赖数据增强或大型辅助网络,计算成本高昂。
- MaDi算法通过引入轻量级的Masker网络,动态生成掩码,使智能体专注于学习任务相关的信息。
- 实验表明,MaDi在多个基准测试中实现了优于或媲美SOTA的泛化性能,且参数增加极少。
📝 摘要(中文)
视觉世界包含大量信息,但智能体接收到的许多输入像素通常包含干扰性刺激。自主智能体需要能够区分有用信息和与任务无关的感知,从而泛化到具有新干扰的未见环境。现有工作通常使用数据增强或带有额外损失函数的大型辅助网络来解决这个问题。我们提出了一种新颖的算法MaDi,它仅通过奖励信号学习掩蔽干扰。在MaDi中,深度强化学习智能体的传统actor-critic结构由一个小的第三个分支Masker补充。这个轻量级神经网络生成一个掩码,以确定actor和critic将接收什么,以便它们可以专注于学习任务。掩码是动态创建的,具体取决于当前输入。我们在DeepMind Control Generalization Benchmark、Distracting Control Suite和一个真实的UR5机械臂上进行了实验。我们的算法通过有用的掩码提高了智能体的注意力,而其高效的Masker网络仅在原始结构中增加了0.2%的参数,这与之前的工作形成对比。MaDi始终如一地实现了优于或可与最先进方法相媲美的泛化结果。
🔬 方法详解
问题定义:在视觉深度强化学习中,智能体面临着大量与任务无关的干扰信息,这些信息会降低学习效率和泛化能力。现有方法,如数据增强和引入大型辅助网络,虽然可以缓解这个问题,但通常计算成本高,或者需要额外的损失函数进行约束。因此,如何在不显著增加计算负担的情况下,让智能体能够自动识别并忽略干扰信息,是一个重要的挑战。
核心思路:MaDi的核心思路是引入一个轻量级的Masker网络,该网络根据当前输入动态生成一个掩码,用于过滤掉输入图像中的干扰信息。Actor和Critic网络只接收经过掩码处理后的图像,从而专注于学习与任务相关的特征。Masker网络通过奖励信号进行训练,鼓励其生成能够提高智能体性能的掩码。
技术框架:MaDi的整体架构是在传统的Actor-Critic框架的基础上,增加了一个Masker网络。Masker网络接收原始输入图像,输出一个与输入图像大小相同的掩码。该掩码与原始图像相乘,得到经过过滤的图像。Actor和Critic网络接收经过过滤的图像作为输入,并分别输出动作和价值估计。整个系统通过强化学习算法进行端到端训练。
关键创新:MaDi的关键创新在于Masker网络的引入,它能够动态地学习掩蔽干扰信息,而无需额外的监督信号或损失函数。与现有方法相比,MaDi的Masker网络非常轻量级,只增加了少量参数,因此不会显著增加计算负担。此外,MaDi的掩码生成过程是动态的,可以根据不同的输入自适应地调整掩码,从而更好地适应不同的环境。
关键设计:Masker网络通常采用卷积神经网络结构,输入为原始图像,输出为与输入图像大小相同的单通道掩码。掩码的值域通常被限制在[0, 1]之间,可以使用Sigmoid函数来实现。Masker网络的损失函数通常是Actor-Critic网络的损失函数的负值,即鼓励Masker网络生成能够提高智能体性能的掩码。为了保证训练的稳定性,可以对Masker网络的输出进行正则化,例如使用L1或L2正则化。
📊 实验亮点
MaDi算法在DeepMind Control Generalization Benchmark、Distracting Control Suite和真实UR5机械臂等多个基准测试中进行了评估。实验结果表明,MaDi算法在泛化性能方面优于或可与最先进的方法相媲美,同时只增加了0.2%的参数量。例如,在某些任务中,MaDi算法的性能比基线方法提高了10%以上。
🎯 应用场景
MaDi算法在机器人控制、自动驾驶、游戏AI等领域具有广泛的应用前景。它可以帮助智能体在复杂的、充满干扰的环境中更好地学习和泛化,提高智能体的鲁棒性和适应性。例如,在自动驾驶中,MaDi可以帮助车辆忽略道路上的无关信息,如广告牌、行人等,从而更专注于驾驶任务。
📄 摘要(原文)
The visual world provides an abundance of information, but many input pixels received by agents often contain distracting stimuli. Autonomous agents need the ability to distinguish useful information from task-irrelevant perceptions, enabling them to generalize to unseen environments with new distractions. Existing works approach this problem using data augmentation or large auxiliary networks with additional loss functions. We introduce MaDi, a novel algorithm that learns to mask distractions by the reward signal only. In MaDi, the conventional actor-critic structure of deep reinforcement learning agents is complemented by a small third sibling, the Masker. This lightweight neural network generates a mask to determine what the actor and critic will receive, such that they can focus on learning the task. The masks are created dynamically, depending on the current input. We run experiments on the DeepMind Control Generalization Benchmark, the Distracting Control Suite, and a real UR5 Robotic Arm. Our algorithm improves the agent's focus with useful masks, while its efficient Masker network only adds 0.2% more parameters to the original structure, in contrast to previous work. MaDi consistently achieves generalization results better than or competitive to state-of-the-art methods.