Pre-trained Language Models Improve the Few-shot Prompt Ability of Decision Transformer

📄 arXiv: 2408.01402v2 📥 PDF

作者: Yu Yang, Pan Xu

分类: cs.LG, cs.AI, cs.CL

发布日期: 2024-08-02 (更新: 2025-12-02)

备注: 2 figures, 10 tables. Published in Transactions on Machine Learning Research (TMLR)

期刊: Transactions on Machine Learning Research, 2025


💡 一句话要点

提出LPDT框架,利用预训练语言模型提升Decision Transformer的少样本Prompt能力

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 离线强化学习 Decision Transformer 预训练语言模型 少样本学习 Prompt学习

📋 核心要点

  1. Prompt-DT方法依赖于特定环境的数据提示,数据收集成本高且不安全,导致少样本能力受限。
  2. LPDT框架利用预训练语言模型提供先验知识,并结合LoRA微调和提示正则化,提升任务区分能力。
  3. 实验表明,LPDT在少量数据下可达到与Prompt-DT相似的性能,验证了各组件的有效性。

📝 摘要(中文)

Decision Transformer (DT) 作为一种离线强化学习算法,利用预先收集的数据集和Transformer建模长序列的能力。最近的研究表明,在DT中使用来自训练任务的部分轨迹作为提示,可以提高其在新任务上的性能,从而产生了Prompt-DT方法。然而,在许多场景中,从特定环境中收集数据既昂贵又不安全,由于基于Transformer的模型对数据需求量大,导致性能欠佳和少样本提示能力有限。此外,预训练中使用的数据集有限,使得Prompt-DT类型的方法难以仅通过提示来区分各种强化学习任务。为了解决这些挑战,我们引入了语言模型初始化的Prompt Decision Transformer (LPDT) 框架,该框架利用预训练语言模型为强化学习任务提供丰富的先验知识,并使用低秩适应 (LoRA) 对序列模型进行微调,以解决元强化学习问题。我们进一步结合了提示正则化,以有效地根据提示特征表示来区分任务。全面的实验研究表明,使用预训练语言模型进行初始化可以提供先验知识,并在某些MuJoCo控制任务中仅使用10%的数据即可实现与Prompt-DT相似的性能。我们还提供了全面的消融研究,以验证每个组件的有效性,包括序列建模、语言模型、提示正则化和提示策略。

🔬 方法详解

问题定义:现有的Prompt-DT方法在离线强化学习中依赖于从特定环境中收集的轨迹数据作为prompt,以提升在新任务上的泛化能力。然而,在许多实际场景中,收集这些数据既昂贵又不安全。此外,由于Transformer模型本身的数据饥渴特性,Prompt-DT在数据量有限的情况下表现不佳,其少样本prompt能力受到限制。更重要的是,Prompt-DT难以仅通过有限的prompt数据区分不同的强化学习任务。

核心思路:LPDT的核心思路是利用预训练语言模型(PLM)所蕴含的丰富先验知识来弥补强化学习数据不足的问题。通过将PLM的知识迁移到Decision Transformer中,LPDT能够更好地理解和区分不同的强化学习任务,从而提升其少样本prompt能力。同时,采用低秩适应(LoRA)方法进行微调,可以在减少训练参数的同时,有效地将PLM的知识融入到DT模型中。

技术框架:LPDT框架主要包含以下几个模块:1) 预训练语言模型:提供丰富的先验知识。2) Decision Transformer:作为序列建模的主体。3) LoRA微调:将PLM的知识迁移到DT模型中,同时减少训练参数。4) Prompt正则化:通过约束prompt的特征表示,增强模型区分不同任务的能力。整体流程是:首先使用预训练语言模型初始化DT模型,然后使用LoRA方法在少量强化学习数据上进行微调,同时应用prompt正则化。

关键创新:LPDT的关键创新在于将预训练语言模型引入到Prompt-DT框架中,利用PLM的先验知识来解决数据不足的问题。与传统的Prompt-DT方法相比,LPDT不再完全依赖于特定环境的数据prompt,而是能够利用PLM的通用知识来提升模型的泛化能力和少样本学习能力。此外,prompt正则化的引入进一步增强了模型区分不同任务的能力。

关键设计:LPDT的关键设计包括:1) 使用预训练语言模型(例如,GPT系列)初始化Decision Transformer的Transformer编码器。2) 采用LoRA进行微调,只训练少量参数,避免灾难性遗忘。3) 引入prompt正则化项,例如,通过最小化不同任务prompt特征表示之间的距离,来增强模型区分任务的能力。具体的损失函数设计和网络结构选择需要根据具体的任务和数据集进行调整。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,LPDT在MuJoCo控制任务中,仅使用10%的数据即可达到与传统Prompt-DT方法相似的性能。消融实验验证了预训练语言模型、LoRA微调和prompt正则化等各个组件的有效性。这些结果表明,LPDT能够有效地利用预训练语言模型的先验知识,提升Decision Transformer的少样本prompt能力。

🎯 应用场景

LPDT框架可应用于各种离线强化学习任务,尤其是在数据收集成本高昂或安全性受限的场景中,例如机器人控制、自动驾驶和医疗决策等。该研究有助于降低强化学习算法对数据的依赖,提高其在实际应用中的可行性和效率,并促进强化学习在更广泛领域的应用。

📄 摘要(原文)

Decision Transformer (DT) has emerged as a promising class of algorithms in offline reinforcement learning (RL) tasks, leveraging pre-collected datasets and Transformer's capability to model long sequences. Recent works have demonstrated that using parts of trajectories from training tasks as prompts in DT enhances its performance on unseen tasks, giving rise to Prompt-DT methods. However, collecting data from specific environments can be both costly and unsafe in many scenarios, leading to suboptimal performance and limited few-shot prompt abilities due to the data-hungry nature of Transformer-based models. Additionally, the limited datasets used in pre-training make it challenging for Prompt-DT type of methods to distinguish between various RL tasks through prompts alone. To address these challenges, we introduce the Language model-initialized Prompt Decision Transformer (LPDT) framework, which leverages pretrained language models providing rich prior knowledge for RL tasks and fine-tunes the sequence model using Low-rank Adaptation (LoRA) for meta-RL problems. We further incorporate prompt regularization to effectively differentiate between tasks based on prompt feature representations. Comprehensive empirical studies demonstrate that initializing with a pre-trained language model provides the prior knowledge and achieves a similar performance with Prompt-DT under only $10\%$ data in some MuJoCo control tasks. We also provide a thorough ablation study to validate the effectiveness of each component, including sequence modeling, language models, prompt regularizations, and prompt strategies.