Diffusion Policies creating a Trust Region for Offline Reinforcement Learning

📄 arXiv: 2405.19690v3 📥 PDF

作者: Tianyu Chen, Zhendong Wang, Mingyuan Zhou

分类: cs.LG, cs.AI

发布日期: 2024-05-30 (更新: 2024-10-31)

备注: NeurIPS 2024

🔗 代码/项目: GITHUB


💡 一句话要点

提出DTQL:通过扩散信任域加速离线强化学习,兼顾性能与效率

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 离线强化学习 扩散模型 信任域 行为克隆 Q学习

📋 核心要点

  1. 现有离线强化学习方法,如DQL,依赖迭代去噪采样生成动作,导致训练和推理速度慢。
  2. DTQL提出双策略框架,结合扩散策略的表达能力和单步策略的效率,通过信任域损失连接。
  3. 实验表明,DTQL在D4RL基准测试中优于其他方法,并在训练和推理速度上表现出更高的效率。

📝 摘要(中文)

离线强化学习(RL)利用预先收集的数据集来训练最优策略。扩散Q学习(DQL)引入扩散模型作为一种强大且富有表现力的策略类,显著提升了离线RL的性能。然而,它依赖于迭代去噪采样来生成动作,这减慢了训练和推理速度。虽然最近的一些尝试试图加速扩散-QL,但训练和/或推理速度的提高往往会导致性能下降。本文提出了一种双策略方法,即扩散信任Q学习(DTQL),它包含一个用于纯行为克隆的扩散策略和一个实用的单步策略。我们通过新引入的扩散信任域损失来桥接这两个策略。扩散策略保持了表达能力,而信任域损失引导单步策略在扩散策略定义的区域内自由探索并寻找模式。DTQL消除了训练和推理过程中对迭代去噪采样的需求,使其具有显著的计算效率。我们在2D bandit场景和gym任务中评估了其有效性和算法特性,并与流行的基于Kullback--Leibler散度的蒸馏方法进行了比较。我们进一步表明,DTQL不仅在大多数D4RL基准任务上优于其他方法,而且在训练和推理速度方面也表现出效率。PyTorch实现可在https://github.com/TianyuCodings/Diffusion_Trusted_Q_Learning获得。

🔬 方法详解

问题定义:离线强化学习旨在利用预先收集的数据集训练最优策略,而无需与环境交互。DQL等方法虽然性能优异,但其依赖迭代去噪采样生成动作,计算成本高昂,严重限制了训练和推理速度。如何在保证性能的同时,提升离线强化学习的效率是一个关键问题。

核心思路:DTQL的核心思路是结合扩散模型的表达能力和单步策略的效率。通过训练一个扩散策略进行行为克隆,并利用一个单步策略进行快速决策。关键在于引入扩散信任域损失,引导单步策略在扩散策略定义的信任区域内探索,从而在保证策略质量的同时,避免了耗时的迭代采样。

技术框架:DTQL包含两个主要模块:扩散策略和单步策略。扩散策略通过行为克隆学习数据集中的行为模式,提供一个可信的策略分布。单步策略则通过Q学习等方法进行训练,目标是在扩散策略的信任区域内找到最优动作。扩散信任域损失用于约束单步策略的动作分布,使其接近扩散策略的输出。

关键创新:DTQL的关键创新在于引入了扩散信任域损失,它有效地桥接了扩散策略和单步策略。与传统的KL散度蒸馏方法不同,扩散信任域损失允许单步策略在扩散策略定义的区域内自由探索,从而更好地利用数据集中的信息,并避免了策略坍塌的问题。此外,DTQL完全消除了推理阶段的迭代采样,显著提升了推理速度。

关键设计:DTQL的关键设计包括:1) 扩散策略的网络结构选择,通常采用标准的扩散模型架构;2) 单步策略的训练方法,可以使用DQN、SAC等算法;3) 扩散信任域损失的具体形式,例如可以使用MMD距离或Wasserstein距离来衡量两个策略分布的差异;4) 信任区域的定义方式,可以通过设置阈值来限制单步策略的动作与扩散策略输出的距离。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

DTQL在D4RL基准测试中取得了显著的性能提升,在多个任务上超越了现有的离线强化学习算法。更重要的是,DTQL在训练和推理速度方面表现出更高的效率,与DQL相比,推理速度提升了数倍,使其更具实用价值。

🎯 应用场景

DTQL可应用于各种需要离线强化学习的场景,例如机器人控制、自动驾驶、推荐系统和金融交易。该方法尤其适用于计算资源有限或对推理速度有较高要求的应用,例如在嵌入式设备上部署智能体或进行实时决策。

📄 摘要(原文)

Offline reinforcement learning (RL) leverages pre-collected datasets to train optimal policies. Diffusion Q-Learning (DQL), introducing diffusion models as a powerful and expressive policy class, significantly boosts the performance of offline RL. However, its reliance on iterative denoising sampling to generate actions slows down both training and inference. While several recent attempts have tried to accelerate diffusion-QL, the improvement in training and/or inference speed often results in degraded performance. In this paper, we introduce a dual policy approach, Diffusion Trusted Q-Learning (DTQL), which comprises a diffusion policy for pure behavior cloning and a practical one-step policy. We bridge the two polices by a newly introduced diffusion trust region loss. The diffusion policy maintains expressiveness, while the trust region loss directs the one-step policy to explore freely and seek modes within the region defined by the diffusion policy. DTQL eliminates the need for iterative denoising sampling during both training and inference, making it remarkably computationally efficient. We evaluate its effectiveness and algorithmic characteristics against popular Kullback--Leibler divergence-based distillation methods in 2D bandit scenarios and gym tasks. We then show that DTQL could not only outperform other methods on the majority of the D4RL benchmark tasks but also demonstrate efficiency in training and inference speeds. The PyTorch implementation is available at https://github.com/TianyuCodings/Diffusion_Trusted_Q_Learning.