GTA: Generative Trajectory Augmentation with Guidance for Offline Reinforcement Learning

📄 arXiv: 2405.16907v5 📥 PDF

作者: Jaewoo Lee, Sujin Yun, Taeyoung Yun, Jinkyoo Park

分类: cs.AI, cs.LG

发布日期: 2024-05-27 (更新: 2024-11-07)

备注: NeurIPS 2024. Previously accepted (Spotlight) to ICLR 2024 Workshop on Generative Models for Decision Making. Jaewoo Lee and Sujin Yun are equal contribution authors

🔗 代码/项目: GITHUB


💡 一句话要点

提出GTA:一种基于生成轨迹增强的离线强化学习方法,提升数据质量。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱四:生成式动作 (Generative Motion)

关键词: 离线强化学习 数据增强 生成模型 扩散模型 轨迹生成

📋 核心要点

  1. 离线强化学习依赖静态数据集,现有数据增强方法难以直接提升数据集质量,导致学习效果受限。
  2. GTA利用扩散模型,通过对轨迹进行噪声处理和有条件去噪,生成高回报且动态合理的增强轨迹。
  3. 实验表明,GTA能有效提升多种离线强化学习算法在不同任务上的性能,并显著改善数据集质量。

📝 摘要(中文)

离线强化学习面临着从静态数据集中学习有效决策策略的挑战,而无需任何在线交互。数据增强技术,如噪声注入和数据合成,旨在通过平滑学习到的状态-动作区域来改进Q函数近似。然而,这些方法通常无法直接提高离线数据集的质量,导致次优结果。为了解决这个问题,我们引入了GTA,即生成轨迹增强,这是一种新颖的生成数据增强方法,旨在通过增强轨迹,使其既具有高回报又在动态上合理,从而丰富离线数据。GTA在数据增强框架内应用扩散模型。GTA部分地对原始轨迹进行噪声处理,然后通过无分类器引导,以放大的回报值为条件进行去噪。我们的结果表明,GTA作为一种通用的数据增强策略,增强了广泛使用的离线强化学习算法在各种具有独特挑战的任务中的性能。此外,我们对GTA增强的数据进行了质量分析,并证明GTA提高了数据的质量。我们的代码可在https://github.com/Jaewoopudding/GTA获取。

🔬 方法详解

问题定义:离线强化学习旨在从静态数据集中学习策略,但现有数据增强方法(如噪声注入)无法有效提升数据集质量,导致Q函数近似不准确,策略学习效果不佳。核心问题在于如何生成高质量、既能提高回报又能保证动态合理性的增强数据。

核心思路:GTA的核心思路是利用生成模型(扩散模型)来生成高质量的增强轨迹。通过对原始轨迹添加噪声,然后利用条件扩散模型进行去噪,生成既能获得高回报,又符合环境动态特性的新轨迹。这种方法旨在弥补离线数据集的不足,提高Q函数学习的准确性。

技术框架:GTA的整体框架包括以下步骤:1) 从离线数据集中采样原始轨迹;2) 对采样的轨迹添加噪声,使其部分损坏;3) 使用条件扩散模型对噪声轨迹进行去噪,其中条件是放大的回报值;4) 将生成的增强轨迹添加到原始数据集中,用于训练离线强化学习算法。扩散模型在这里充当轨迹生成器的角色。

关键创新:GTA的关键创新在于将扩散模型引入到离线强化学习的数据增强中,并使用无分类器引导,以放大的回报值为条件进行轨迹生成。这种方法能够生成既具有高回报,又符合环境动态特性的增强轨迹,从而有效提升离线强化学习算法的性能。与传统的噪声注入方法相比,GTA能够生成更具结构性和信息量的增强数据。

关键设计:GTA的关键设计包括:1) 使用扩散模型作为轨迹生成器,扩散模型的具体结构(例如,U-Net)和训练方式需要根据具体任务进行调整;2) 使用无分类器引导,通过调整回报值的放大系数来控制生成轨迹的回报水平;3) 损失函数的设计需要保证生成轨迹的动态合理性,例如,可以使用轨迹的预测误差作为损失函数的一部分。具体参数设置需要根据实验结果进行调整。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,GTA在多个离线强化学习任务中显著提升了性能。例如,在D4RL benchmark上,GTA能够提升SAC、CQL等主流算法的性能,平均提升幅度超过10%。此外,对增强数据质量的分析表明,GTA生成的轨迹具有更高的回报和更强的动态合理性。

🎯 应用场景

GTA可广泛应用于离线强化学习领域,尤其适用于数据稀缺或质量不高的场景。例如,在医疗诊断、自动驾驶、金融交易等领域,可以利用历史数据训练智能体,并通过GTA生成高质量的增强数据,提升智能体的决策能力。该方法还有助于降低在线探索的成本和风险。

📄 摘要(原文)

Offline Reinforcement Learning (Offline RL) presents challenges of learning effective decision-making policies from static datasets without any online interactions. Data augmentation techniques, such as noise injection and data synthesizing, aim to improve Q-function approximation by smoothing the learned state-action region. However, these methods often fall short of directly improving the quality of offline datasets, leading to suboptimal results. In response, we introduce GTA, Generative Trajectory Augmentation, a novel generative data augmentation approach designed to enrich offline data by augmenting trajectories to be both high-rewarding and dynamically plausible. GTA applies a diffusion model within the data augmentation framework. GTA partially noises original trajectories and then denoises them with classifier-free guidance via conditioning on amplified return value. Our results show that GTA, as a general data augmentation strategy, enhances the performance of widely used offline RL algorithms across various tasks with unique challenges. Furthermore, we conduct a quality analysis of data augmented by GTA and demonstrate that GTA improves the quality of the data. Our code is available at https://github.com/Jaewoopudding/GTA