Prompt Diffusion Robustifies Any-Modality Prompt Learning

📄 arXiv: 2410.20164v1 📥 PDF

作者: Yingjun Du, Gaowen Liu, Yuzhang Shang, Yuguang Yao, Ramana Kompella, Cees G. M. Snoek

分类: cs.LG, cs.CV

发布日期: 2024-10-26

备注: Under review


💡 一句话要点

提出Prompt Diffusion,提升任意模态Prompt Learning的鲁棒性。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: Prompt Learning 扩散模型 零样本学习 少样本学习 领域泛化 分布偏移 多模态学习

📋 核心要点

  1. 现有Prompt Learning方法依赖固定Prompt,在分布偏移下泛化能力不足,是核心问题。
  2. Prompt Diffusion通过扩散模型学习Prompt空间中的生成过程,为每个样本定制Prompt。
  3. 实验表明,Prompt Diffusion能显著提升Prompt Learning在多种泛化任务中的鲁棒性。

📝 摘要(中文)

本文提出了一种名为Prompt Diffusion的方法,旨在提升基于Prompt的基础模型在零样本和少样本学习中的泛化能力。传统方法使用固定Prompt,容易受到分布偏移的影响。Prompt Diffusion利用扩散模型逐步优化Prompt,为每个样本生成定制化的Prompt。具体而言,首先优化一组Prompt,为每个样本获得过拟合的Prompt。然后,在Prompt空间内训练一个Prompt扩散模型,学习从随机Prompt到过拟合Prompt的生成过程。在推理阶段,仅使用训练好的Prompt扩散模型,从随机Prompt逐步生成定制化的Prompt。Prompt Diffusion具有通用性、灵活性和模态无关性,可以无缝嵌入到现有的文本、视觉或多模态Prompt Learning方法中。该扩散模型采用基于ODE的快速采样策略,仅需五步即可优化测试样本的Prompt,在性能提升和计算效率之间取得了良好的平衡。在15个数据集上的分类任务中,Prompt Diffusion在base-to-new泛化、跨数据集泛化和领域泛化方面均表现出更强的鲁棒性。

🔬 方法详解

问题定义:现有的基于Prompt的学习方法,特别是零样本和少样本学习,依赖于预定义的、固定的Prompt。这些固定Prompt在训练数据分布与测试数据分布存在差异时,性能会显著下降,即缺乏鲁棒性。论文旨在解决Prompt Learning在分布偏移下的泛化能力问题。

核心思路:论文的核心思路是利用扩散模型学习Prompt空间中的数据分布,从而能够根据每个输入样本生成定制化的Prompt。通过将Prompt视为一个连续的变量,并使用扩散模型学习从随机噪声到特定Prompt的生成过程,模型能够更好地适应不同的输入样本,从而提高泛化能力。

技术框架:Prompt Diffusion的整体框架包含以下几个主要阶段:1) Prompt优化阶段:首先,针对每个训练样本,优化一组Prompt,使其能够很好地拟合该样本。这些Prompt可以被认为是该样本的“过拟合”Prompt。2) Prompt扩散模型训练阶段:利用扩散模型学习从随机Prompt到过拟合Prompt的生成过程。扩散模型通过逐步添加噪声到Prompt,然后学习逆向过程,即从噪声中恢复原始Prompt。3) 推理阶段:在测试时,从随机Prompt开始,利用训练好的扩散模型逐步生成定制化的Prompt,然后使用该Prompt进行分类。

关键创新:Prompt Diffusion的关键创新在于将扩散模型引入到Prompt Learning中,从而能够动态地生成Prompt,而不是依赖于固定的Prompt。这种方法能够更好地适应不同的输入样本,从而提高泛化能力。此外,使用基于ODE的快速采样策略,在保证性能的同时,提高了计算效率。

关键设计:Prompt Diffusion使用扩散模型来学习Prompt空间中的数据分布。扩散模型采用基于ODE的采样策略,可以在较少的步骤内生成高质量的Prompt。具体来说,论文使用DDIM(Denoising Diffusion Implicit Models)的采样方法,仅需5步即可完成Prompt的生成,在性能和效率之间取得平衡。损失函数方面,扩散模型的训练目标是最小化重构误差,即最小化从噪声中恢复原始Prompt的误差。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,Prompt Diffusion在多个数据集上显著提升了Prompt Learning的性能。例如,在base-to-new泛化任务中,Prompt Diffusion相较于传统方法取得了明显的性能提升。此外,Prompt Diffusion在跨数据集泛化和领域泛化方面也表现出更强的鲁棒性,证明了其在处理分布偏移问题上的有效性。

🎯 应用场景

Prompt Diffusion可广泛应用于自然语言处理、计算机视觉和多模态学习等领域,尤其是在零样本和少样本学习场景下。该方法能够提升模型在分布偏移下的鲁棒性,使其在实际应用中更具价值。例如,可以用于跨领域图像分类、文本分类等任务,并有望推动Prompt Learning在更多实际场景中的应用。

📄 摘要(原文)

Foundation models enable prompt-based classifiers for zero-shot and few-shot learning. Nonetheless, the conventional method of employing fixed prompts suffers from distributional shifts that negatively impact generalizability to unseen samples. This paper introduces prompt diffusion, which uses a diffusion model to gradually refine the prompts to obtain a customized prompt for each sample. Specifically, we first optimize a collection of prompts to obtain over-fitted prompts per sample. Then, we propose a prompt diffusion model within the prompt space, enabling the training of a generative transition process from a random prompt to its overfitted prompt. As we cannot access the label of a test image during inference, our model gradually generates customized prompts solely from random prompts using our trained, prompt diffusion. Our prompt diffusion is generic, flexible, and modality-agnostic, making it a simple plug-and-play module seamlessly embedded into existing prompt learning methods for textual, visual, or multi-modal prompt learning. Our diffusion model uses a fast ODE-based sampling strategy to optimize test sample prompts in just five steps, offering a good trade-off between performance improvement and computational efficiency. For all prompt learning methods tested, adding prompt diffusion yields more robust results for base-to-new generalization, cross-dataset generalization, and domain generalization in classification tasks tested over 15 diverse datasets.