Offline Reinforcement Learning with Wasserstein Regularization via Optimal Transport Maps

📄 arXiv: 2507.10843v1 📥 PDF

作者: Motoki Omura, Yusuke Mukuta, Kazuki Ota, Takayuki Osa, Tatsuya Harada

分类: cs.LG, cs.AI, cs.RO

发布日期: 2025-07-14

备注: Accepted at RLC 2025

🔗 代码/项目: GITHUB


💡 一句话要点

提出基于最优传输映射和Wasserstein正则化的离线强化学习方法,解决分布偏移问题。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 离线强化学习 Wasserstein距离 最优传输映射 分布偏移 正则化

📋 核心要点

  1. 离线强化学习面临分布偏移问题,导致策略在训练数据之外表现不佳,现有方法如f-散度正则化对此问题不够鲁棒。
  2. 论文提出使用Wasserstein距离进行正则化,该距离对分布外数据更鲁棒,并能捕捉动作间的相似性,从而提升策略泛化能力。
  3. 通过输入凸神经网络建模最优传输映射,实现了无判别器的Wasserstein距离计算,避免了对抗训练,并在D4RL数据集上验证了有效性。

📝 摘要(中文)

离线强化学习(RL)旨在从静态数据集中学习最优策略,这在数据收集成本高昂的场景(如机器人技术)中尤其有价值。离线RL的一个主要挑战是分布偏移,即学习到的策略偏离数据集分布,可能导致不可靠的越界动作。为了缓解这个问题,已经采用了正则化技术。虽然许多现有方法使用基于密度比率的度量(例如$f$-散度)进行正则化,但我们提出了一种利用Wasserstein距离的方法,该方法对越界数据具有鲁棒性,并能捕捉动作之间的相似性。我们的方法采用输入凸神经网络(ICNN)来建模最优传输映射,从而以无判别器的方式计算Wasserstein距离,从而避免了对抗训练并确保了稳定的学习。在D4RL基准数据集上,我们的方法表现出与广泛使用的现有方法相当或更优越的性能。代码可在https://github.com/motokiomura/Q-DOT 获取。

🔬 方法详解

问题定义:离线强化学习旨在利用静态数据集训练策略,但由于策略可能探索到训练数据未覆盖的状态-动作空间,导致分布偏移问题。现有方法,如基于f-散度的正则化,在处理分布外数据时不够鲁棒,容易导致策略性能下降。因此,需要一种更鲁棒的正则化方法来约束策略行为,使其更接近数据集分布。

核心思路:论文的核心思路是利用Wasserstein距离来度量策略生成的动作分布与数据集中的动作分布之间的差异,并将其作为正则化项加入到目标函数中。Wasserstein距离相比于f-散度,对分布外数据更加鲁棒,能够更好地捕捉动作之间的相似性,从而引导策略学习更安全的行为。

技术框架:整体框架包括一个策略网络和一个Q函数网络。策略网络负责生成动作,Q函数网络负责评估状态-动作对的价值。Wasserstein距离通过最优传输映射来计算,该映射由输入凸神经网络(ICNN)建模。训练过程中,策略网络的目标是最大化Q函数值,同时最小化与数据集动作分布的Wasserstein距离。Q函数网络的目标是准确估计状态-动作对的价值。

关键创新:最重要的技术创新点在于使用最优传输映射来计算Wasserstein距离,并将其应用于离线强化学习的正则化。传统的Wasserstein距离计算通常需要求解一个复杂的优化问题,而通过ICNN建模最优传输映射,可以将其转化为一个简单的前向传播过程,大大提高了计算效率。此外,该方法避免了对抗训练,从而保证了训练的稳定性。

关键设计:关键设计包括:1) 使用输入凸神经网络(ICNN)来建模最优传输映射,确保映射的凸性,从而保证Wasserstein距离的有效性。2) 将Wasserstein距离作为正则化项加入到策略网络的损失函数中,平衡策略性能和安全性。3) 采用合适的优化算法来训练策略网络和Q函数网络,例如Adam优化器。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

该方法在D4RL基准数据集上进行了评估,结果表明,与现有的离线强化学习方法相比,该方法取得了相当或更优越的性能。尤其是在处理分布偏移问题时,该方法的优势更加明显。实验结果验证了Wasserstein正则化在离线强化学习中的有效性。

🎯 应用场景

该研究成果可应用于机器人控制、自动驾驶、医疗诊断等领域。在这些领域中,数据收集成本高昂或存在安全风险,离线强化学习具有重要价值。通过Wasserstein正则化,可以提高离线学习策略的鲁棒性和安全性,使其能够更好地适应真实环境中的不确定性。

📄 摘要(原文)

Offline reinforcement learning (RL) aims to learn an optimal policy from a static dataset, making it particularly valuable in scenarios where data collection is costly, such as robotics. A major challenge in offline RL is distributional shift, where the learned policy deviates from the dataset distribution, potentially leading to unreliable out-of-distribution actions. To mitigate this issue, regularization techniques have been employed. While many existing methods utilize density ratio-based measures, such as the $f$-divergence, for regularization, we propose an approach that utilizes the Wasserstein distance, which is robust to out-of-distribution data and captures the similarity between actions. Our method employs input-convex neural networks (ICNNs) to model optimal transport maps, enabling the computation of the Wasserstein distance in a discriminator-free manner, thereby avoiding adversarial training and ensuring stable learning. Our approach demonstrates comparable or superior performance to widely used existing methods on the D4RL benchmark dataset. The code is available at https://github.com/motokiomura/Q-DOT .