Contrastive Representation for Data Filtering in Cross-Domain Offline Reinforcement Learning
作者: Xiaoyu Wen, Chenjia Bai, Kang Xu, Xudong Yu, Yang Zhang, Xuelong Li, Zhen Wang
分类: cs.LG, cs.AI
发布日期: 2024-05-10
备注: This paper has been accepted by ICML2024
💡 一句话要点
提出基于对比表示的数据过滤方法,解决跨域离线强化学习中的数据异构问题。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 跨域强化学习 离线强化学习 对比学习 数据过滤 表示学习
📋 核心要点
- 跨域离线强化学习中,直接合并不同动态特性的源域和目标域数据会导致性能下降,这是核心问题。
- 论文提出通过对比学习不同域的转移样本,学习域不变表示,以此衡量域间动态差异并进行数据过滤。
- 实验表明,该方法仅使用少量目标域数据,即可达到甚至超过使用全部目标域数据的现有方法的性能。
📝 摘要(中文)
本文提出了一种基于对比表示的跨域离线强化学习数据过滤方法,旨在解决源域和目标域数据动态不匹配导致性能下降的问题。该方法通过对比学习不同域的转移样本来学习表示,从而衡量域之间的差异。这种对比目标能够有效捕捉转移函数之间的互信息差距,避免了直接测量动态差距时可能出现的不稳定问题。基于学习到的表示,本文提出了一种数据过滤算法,根据对比得分函数选择性地共享源域的转移样本。实验结果表明,该方法在各种任务上均取得了优异的性能,仅使用10%的目标数据即可达到使用100%目标数据集的先进方法的89.2%的性能。
🔬 方法详解
问题定义:跨域离线强化学习旨在利用源域数据提升目标域的策略学习效果。然而,源域和目标域的动态特性差异(dynamics mismatch)会导致负迁移,直接混合两个域的数据反而会降低性能。现有方法通常依赖于领域分类器来衡量动态差距,但这些方法依赖于配对域的可迁移性假设,并且在处理差异显著的域时,动态差距的度量可能变得不稳定。
核心思路:本文的核心思路是通过学习一种能够区分不同域转移样本的表示,来衡量域之间的差异。这种表示学习的目标不是直接建模动态差距,而是通过对比学习,捕捉不同域转移函数之间的互信息差距。互信息差距能够更稳定地反映域之间的差异,避免了动态差距度量的不稳定性。
技术框架:该方法主要包含两个阶段:表示学习阶段和数据过滤阶段。在表示学习阶段,使用对比学习目标,训练一个编码器,将状态-动作对映射到表示空间。对比学习的目标是拉近同一域的转移样本的表示,推远不同域的转移样本的表示。在数据过滤阶段,使用学习到的表示,计算源域转移样本与目标域转移样本之间的对比得分。根据对比得分,选择性地将源域转移样本添加到目标域数据集中。
关键创新:该方法最重要的创新点在于使用对比学习来衡量域之间的差异,而不是直接建模动态差距。对比学习能够更稳定地捕捉不同域转移函数之间的互信息差距,避免了动态差距度量的不稳定性。此外,该方法不需要配对域的可迁移性假设,适用性更广。
关键设计:对比学习的损失函数采用InfoNCE损失,用于最大化同一域转移样本表示之间的一致性,最小化不同域转移样本表示之间的一致性。数据过滤阶段,使用对比得分作为权重,对源域转移样本进行加权,然后将加权后的源域数据与目标域数据合并,用于训练强化学习策略。具体的网络结构和超参数设置根据不同的任务进行调整。
📊 实验亮点
实验结果表明,该方法在多个跨域离线强化学习任务上取得了显著的性能提升。例如,在某些任务上,仅使用10%的目标数据,该方法即可达到使用100%目标数据集的现有方法的89.2%的性能。与现有方法相比,该方法在数据利用率和性能方面均具有优势。
🎯 应用场景
该研究成果可应用于机器人控制、自动驾驶、游戏AI等领域,尤其是在数据获取成本高昂或难以直接在目标环境中进行探索的场景下。通过利用其他相关领域的数据,可以显著降低对目标环境数据的依赖,加速策略学习过程,提高智能系统的泛化能力和鲁棒性。
📄 摘要(原文)
Cross-domain offline reinforcement learning leverages source domain data with diverse transition dynamics to alleviate the data requirement for the target domain. However, simply merging the data of two domains leads to performance degradation due to the dynamics mismatch. Existing methods address this problem by measuring the dynamics gap via domain classifiers while relying on the assumptions of the transferability of paired domains. In this paper, we propose a novel representation-based approach to measure the domain gap, where the representation is learned through a contrastive objective by sampling transitions from different domains. We show that such an objective recovers the mutual-information gap of transition functions in two domains without suffering from the unbounded issue of the dynamics gap in handling significantly different domains. Based on the representations, we introduce a data filtering algorithm that selectively shares transitions from the source domain according to the contrastive score functions. Empirical results on various tasks demonstrate that our method achieves superior performance, using only 10% of the target data to achieve 89.2% of the performance on 100% target dataset with state-of-the-art methods.