Tackling Data Corruption in Offline Reinforcement Learning via Sequence Modeling

📄 arXiv: 2407.04285v4 📥 PDF

作者: Jiawei Xu, Rui Yang, Shuang Qiu, Feng Luo, Meng Fang, Baoxiang Wang, Lei Han

分类: cs.LG, cs.AI

发布日期: 2024-07-05 (更新: 2025-03-02)

备注: Accepted by ICLR2025

🔗 代码/项目: GITHUB


💡 一句话要点

提出RDT,通过序列建模解决离线强化学习中的数据损坏问题

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 离线强化学习 数据损坏 序列建模 鲁棒性 决策转换器

📋 核心要点

  1. 现有离线强化学习方法在有限且受损的数据集上表现不佳,无法有效应对真实世界数据中的噪声和错误。
  2. 论文提出Robust Decision Transformer (RDT),通过嵌入dropout、高斯加权学习和迭代数据校正增强模型鲁棒性。
  3. 实验表明,RDT在MuJoCo、Kitchen和Adroit任务上,相比现有方法,在各种数据损坏场景下均表现出更优性能。

📝 摘要(中文)

离线强化学习旨在从离线数据集中学习策略,避免不安全和高成本的在线交互,从而扩展数据驱动的决策。然而,现实世界中传感器或人类收集的数据通常包含噪声和错误,这对现有的离线强化学习方法构成了重大挑战,尤其是在数据有限的情况下。研究表明,基于时序差分学习的离线强化学习方法在数据损坏且数据集有限时表现不佳。相比之下,诸如Decision Transformer之类的序列建模方法即使没有专门的修改,也表现出对数据损坏的鲁棒性。为了充分发挥序列建模的潜力,我们提出了一种鲁棒的决策转换器(RDT),它结合了三种简单而有效的鲁棒技术:嵌入dropout以提高模型对错误输入的鲁棒性,高斯加权学习以减轻损坏标签的影响,以及迭代数据校正以消除源数据中损坏的数据。在MuJoCo、Kitchen和Adroit任务上的大量实验表明,与先前的方法相比,RDT在各种数据损坏场景下表现出卓越的性能。此外,RDT在更具挑战性的环境中表现出显著的鲁棒性,该环境结合了训练时的数据损坏和测试时的观察扰动。这些结果突出了序列建模在从噪声或损坏的离线数据集中学习的潜力,从而促进了离线强化学习在现实世界中的可靠应用。

🔬 方法详解

问题定义:论文旨在解决离线强化学习中,由于离线数据集存在数据损坏(data corruption)而导致策略学习效果下降的问题。现有基于时序差分学习的离线强化学习方法,在面对噪声或错误数据时,尤其是在数据量有限的情况下,鲁棒性不足,难以学习到有效的策略。

核心思路:论文的核心思路是利用序列建模方法(特别是Decision Transformer)本身对数据损坏的天然鲁棒性,并在此基础上进一步增强其鲁棒性。通过引入一系列简单有效的技术,使模型能够更好地从受损的离线数据集中学习,从而提升策略性能。

技术框架:RDT的整体框架基于Decision Transformer,它将强化学习问题转化为序列建模问题。主要包含以下几个阶段:1)数据预处理:对离线数据集进行初步处理,包括状态、动作、奖励等信息的提取。2)模型训练:使用改进的Decision Transformer模型进行训练,包括嵌入dropout、高斯加权学习等技术。3)迭代数据校正:通过模型预测结果,识别并消除数据集中潜在的损坏数据。4)策略评估:在测试环境中评估学习到的策略性能。

关键创新:论文的关键创新在于提出了三种简单但有效的鲁棒性增强技术,并将其集成到Decision Transformer框架中:1)嵌入dropout:通过随机丢弃部分嵌入向量,增强模型对输入噪声的鲁棒性。2)高斯加权学习:使用高斯权重对损失函数进行加权,降低损坏标签对模型训练的影响。3)迭代数据校正:通过模型预测结果,识别并消除数据集中潜在的损坏数据,从而提高数据集质量。

关键设计:1)嵌入dropout:dropout概率设置为一个合适的值,以平衡模型的表达能力和鲁棒性。2)高斯加权学习:高斯权重的方差根据数据集的噪声水平进行调整。3)迭代数据校正:设定一个阈值,用于判断数据是否损坏,并将其从数据集中移除。损失函数采用标准的序列建模损失函数,网络结构与Decision Transformer保持一致,采用Transformer架构。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,RDT在MuJoCo、Kitchen和Adroit等多个任务上,相比于基线方法,在各种数据损坏场景下均取得了显著的性能提升。尤其是在训练时数据损坏和测试时观察扰动同时存在的情况下,RDT依然表现出强大的鲁棒性,验证了其在复杂环境下的适用性。

🎯 应用场景

该研究成果可应用于各种需要从噪声或损坏的离线数据集中学习策略的场景,例如机器人控制、自动驾驶、医疗诊断等。通过提高离线强化学习算法的鲁棒性,可以降低对数据质量的要求,从而更容易地将强化学习技术应用于实际问题中,具有重要的实际应用价值和潜力。

📄 摘要(原文)

Learning policy from offline datasets through offline reinforcement learning (RL) holds promise for scaling data-driven decision-making while avoiding unsafe and costly online interactions. However, real-world data collected from sensors or humans often contains noise and errors, posing a significant challenge for existing offline RL methods, particularly when the real-world data is limited. Our study reveals that prior research focusing on adapting predominant offline RL methods based on temporal difference learning still falls short under data corruption when the dataset is limited. In contrast, we discover that vanilla sequence modeling methods, such as Decision Transformer, exhibit robustness against data corruption, even without specialized modifications. To unlock the full potential of sequence modeling, we propose Robust Decision Rransformer (RDT) by incorporating three simple yet effective robust techniques: embedding dropout to improve the model's robustness against erroneous inputs, Gaussian weighted learning to mitigate the effects of corrupted labels, and iterative data correction to eliminate corrupted data from the source. Extensive experiments on MuJoCo, Kitchen, and Adroit tasks demonstrate RDT's superior performance under various data corruption scenarios compared to prior methods. Furthermore, RDT exhibits remarkable robustness in a more challenging setting that combines training-time data corruption with test-time observation perturbations. These results highlight the potential of sequence modeling for learning from noisy or corrupted offline datasets, thereby promoting the reliable application of offline RL in real-world scenarios. Our code is available at https://github.com/jiawei415/RobustDecisionTransformer.