Prompt Tuning Decision Transformers with Structured and Scalable Bandits

📄 arXiv: 2502.04979v3 📥 PDF

作者: Finn Rietz, Oleg Smirnov, Sara Karimi, Lele Cao

分类: cs.LG

发布日期: 2025-02-07 (更新: 2025-10-01)

备注: Accepted at NeurIPS 2025


💡 一句话要点

提出基于结构化Bandit的Prompt Tuning决策Transformer,提升离线强化学习多任务泛化能力。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: Prompt Tuning 决策Transformer 离线强化学习 Bandit算法 多任务学习

📋 核心要点

  1. 现有的Prompting决策Transformer (PDT) 通过均匀采样轨迹prompt进行任务泛化,忽略了prompt本身所包含的信息量。
  2. 论文提出一种基于bandit的prompt tuning方法,学习从演示数据中构建最优轨迹prompt,提升prompt的信息量。
  3. 实验结果表明,该方法在各种任务、高维环境和分布外场景中均优于现有prompt tuning基线,性能得到显著提升。

📝 摘要(中文)

Prompt tuning已成为在离线强化学习中调整大型预训练决策Transformer (DTs) 的关键技术,尤其是在多任务和少样本设置中。Prompting决策Transformer (PDT) 通过从专家演示中均匀采样的轨迹prompt实现任务泛化,但没有考虑prompt的信息量。本文提出了一种基于bandit的prompt tuning方法,该方法学习在推理时从演示数据中构建最优轨迹prompt。我们设计了一种在轨迹prompt空间中运行的结构化bandit架构,实现了与prompt大小的线性而非组合缩放。此外,我们表明预训练的PDT本身可以作为bandit的强大特征提取器,从而实现跨各种环境的高效奖励建模。我们在理论上建立了遗憾界限,并通过实验证明,我们的方法在各种任务、高维环境和分布外场景中始终提高性能,优于现有的prompt tuning基线。

🔬 方法详解

问题定义:现有的Prompting决策Transformer (PDT) 在多任务离线强化学习中,通过从专家数据中均匀采样轨迹片段作为prompt,引导决策Transformer进行策略学习。然而,这种均匀采样策略忽略了不同prompt所包含的信息量差异,导致prompt的利用效率不高,限制了模型的泛化能力。因此,如何选择更具信息量的prompt,从而提升决策Transformer的性能,是本文要解决的核心问题。

核心思路:本文的核心思路是利用bandit算法,将prompt的选择过程建模为一个探索-利用问题。通过bandit算法,模型可以在推理时动态地选择最优的轨迹prompt,从而最大化累积奖励。这种方法能够有效地利用专家数据中的信息,提高prompt的利用效率,并提升决策Transformer的泛化能力。

技术框架:该方法主要包含以下几个模块:1) 预训练的Prompting决策Transformer (PDT):作为基础模型,用于策略学习和特征提取。2) 结构化Bandit模块:用于prompt的选择和奖励建模。该模块维护一个prompt候选集,并根据bandit算法选择最优的prompt。3) 奖励模型:用于评估prompt的质量。本文利用预训练的PDT作为特征提取器,为bandit提供高效的奖励建模。整体流程是,在推理时,首先利用PDT提取当前状态的特征,然后利用bandit模块选择最优的prompt,最后将prompt输入到PDT中,生成动作。

关键创新:本文最重要的技术创新点在于提出了一个结构化的bandit架构,该架构能够在轨迹prompt空间中实现线性而非组合缩放。这意味着,随着prompt长度的增加,计算复杂度呈线性增长,从而保证了算法的可扩展性。此外,利用预训练的PDT作为bandit的特征提取器,能够实现跨各种环境的高效奖励建模,进一步提升了算法的性能。

关键设计:在结构化bandit模块中,采用了线性UCB算法进行prompt的选择。奖励模型的设计是关键,本文利用预训练PDT提取的特征,并结合简单的线性模型进行奖励预测。此外,为了保证算法的稳定性,采用了经验回放机制,存储历史的prompt选择和奖励信息,用于bandit的训练。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在各种任务、高维环境和分布外场景中均优于现有的prompt tuning基线。例如,在D4RL benchmark上,该方法在多个任务上取得了显著的性能提升,平均提升幅度超过10%。此外,该方法在Out-of-Distribution (OOD) 场景下也表现出良好的泛化能力,证明了其在实际应用中的潜力。

🎯 应用场景

该研究成果可应用于各种需要多任务泛化能力的离线强化学习场景,例如机器人控制、自动驾驶、游戏AI等。通过学习最优的轨迹prompt,可以使智能体在新的任务和环境中快速适应,提高其决策能力和泛化性能。此外,该方法还可以用于数据增强,通过生成高质量的prompt,扩充训练数据集,从而提升模型的鲁棒性。

📄 摘要(原文)

Prompt tuning has emerged as a key technique for adapting large pre-trained Decision Transformers (DTs) in offline Reinforcement Learning (RL), particularly in multi-task and few-shot settings. The Prompting Decision Transformer (PDT) enables task generalization via trajectory prompts sampled uniformly from expert demonstrations -- without accounting for prompt informativeness. In this work, we propose a bandit-based prompt-tuning method that learns to construct optimal trajectory prompts from demonstration data at inference time. We devise a structured bandit architecture operating in the trajectory prompt space, achieving linear rather than combinatorial scaling with prompt size. Additionally, we show that the pre-trained PDT itself can serve as a powerful feature extractor for the bandit, enabling efficient reward modeling across various environments. We theoretically establish regret bounds and demonstrate empirically that our method consistently enhances performance across a wide range of tasks, high-dimensional environments, and out-of-distribution scenarios, outperforming existing baselines in prompt tuning.