Prompt Baking

📄 arXiv: 2409.13697v1 📥 PDF

作者: Aman Bhargava, Cameron Witkowski, Alexander Detkov, Matt Thomson

分类: cs.CL, cs.AI

发布日期: 2024-09-04

备注: 25 pages, 8 figures


💡 一句话要点

Prompt Baking:将Prompt信息烘焙到LLM权重中,提升零样本性能并更新模型知识。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: Prompt工程 模型微调 知识蒸馏 持续学习 自监督学习

📋 核心要点

  1. 大型语言模型行为调整主要依赖Prompt和权重更新,但Prompt易忘且权重更新成本高昂。
  2. Prompt Baking通过最小化KL散度,将Prompt信息编码到LLM权重中,实现Prompt的持久化。
  3. 实验表明,Prompt Baking能有效提升零样本性能、更新模型知识,并缓解长序列中的Prompt遗忘问题。

📝 摘要(中文)

本文提出了一种名为“Prompt Baking”的技术,旨在将Prompt信息烘焙到大型语言模型(LLM)的权重中。该方法将Prompt u和初始权重θ转换为新的权重θ_u,使得“烘焙”后的LLM的行为类似于原始的、经过Prompt引导的LLM。在数学上,该方法最小化P_θ(· | u)和P_{θ_u}(·)之间的KL散度,其中P是LLM在token序列上的概率分布。实验结果表明,Prompt可以很容易地烘焙到权重更新中。烘焙思维链Prompt可以提高GSM8K、ASDiv、MBPP、ARC-Easy、ARC-Challenge和CommonsenseQA基准上的零样本性能。烘焙新闻标题可以直接更新LLM的知识。烘焙指令和角色可以缓解长序列中的“Prompt遗忘”问题。此外,提前停止烘焙可以创建“半烘焙”模型,从而连续地调整Prompt强度。烘焙后的模型保留了对进一步Prompt和烘焙的敏感性,包括使用烘焙的Prompt进行重新Prompt。令人惊讶的是,重新Prompt的模型在指令遵循以及数学推理和编码基准测试中产生了进一步的性能提升。将重新Prompt和重新烘焙推向极致,产生了一种迭代自改进的形式,称为Prompt Pursuit,并且在指令遵循方面的初步结果显示出显著的性能提升。最后,讨论了对AI安全、持续模型更新、增强基于LLM的代理的实时学习能力以及生成更稳定的AI角色的影响。

🔬 方法详解

问题定义:现有大型语言模型(LLM)的行为调整主要依赖于Prompt工程和模型微调。Prompt工程简单易用,但存在Prompt易遗忘、对长序列效果不佳等问题。模型微调虽然能实现更持久的行为改变,但需要大量数据和计算资源,成本高昂。因此,如何以更高效的方式将Prompt的优势融入模型权重,是一个亟待解决的问题。

核心思路:Prompt Baking的核心思想是将Prompt信息“烘焙”到LLM的权重中,使其在没有显式Prompt的情况下也能表现出与Prompt引导下相似的行为。通过这种方式,可以实现Prompt的持久化,避免Prompt遗忘问题,并降低推理成本。

技术框架:Prompt Baking的技术框架主要包括以下几个步骤:1. 选择一个预训练的LLM,并确定要烘焙的Prompt。2. 使用Prompt引导LLM生成输出。3. 通过优化LLM的权重,使得LLM在没有Prompt的情况下,生成的输出尽可能接近Prompt引导下的输出。具体而言,通过最小化P_θ(· | u)和P_{θ_u}(·)之间的KL散度来实现,其中P_θ(· | u)是原始LLM在Prompt u下的token序列概率分布,P_{θ_u}(·)是烘焙后的LLM的token序列概率分布。

关键创新:Prompt Baking的关键创新在于它提供了一种将Prompt信息直接嵌入到模型权重中的方法。与传统的Prompt工程相比,Prompt Baking避免了Prompt遗忘问题,并降低了推理时的计算成本。与模型微调相比,Prompt Baking需要的训练数据更少,训练过程更高效。此外,Prompt Pursuit通过迭代地重新Prompt和重新烘焙,实现了模型的自改进。

关键设计:Prompt Baking的关键设计包括:1. 使用KL散度作为损失函数,以确保烘焙后的模型行为与Prompt引导下的模型行为尽可能相似。2. 允许提前停止烘焙过程,以创建“半烘焙”模型,从而实现对Prompt强度的连续控制。3. 支持对烘焙后的模型进行进一步的Prompt和烘焙,以实现更复杂的行为调整和模型自改进。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,Prompt Baking在多个基准测试中取得了显著的性能提升。例如,烘焙思维链Prompt可以提高GSM8K、ASDiv、MBPP、ARC-Easy、ARC-Challenge和CommonsenseQA基准上的零样本性能。此外,Prompt Pursuit在指令遵循方面取得了显著的性能提升,展示了模型自改进的潜力。

🎯 应用场景

Prompt Baking具有广泛的应用前景,包括:AI安全(通过烘焙安全指令来约束模型行为)、持续模型更新(通过烘焙新知识来不断更新模型)、增强LLM代理的实时学习能力(通过烘焙经验来提升代理的决策能力)、以及生成更稳定的AI角色(通过烘焙角色设定来确保角色一致性)。

📄 摘要(原文)

Two primary ways to change LLM behavior are prompting and weight updates (e.g., fine-tuning). Prompting LLMs is simple and effective, specifying the desired changes explicitly in natural language, whereas weight updates provide more expressive and permanent behavior changes, specified implicitly via training on large datasets. We present a technique for "baking" prompts into the weights of an LLM. Prompt Baking converts a prompt $u$ and initial weights $θ$ to a new set of weights $θ_u$ such that new "baked" LLM behaves like the original prompted LLM. Mathematically, we minimize the KL divergence between $P_θ(\cdot | u)$ and $P_{θ_u}(\cdot)$, where $P$ is the LLM's probability distribution over token sequences. Across all our experiments, we find prompts can be readily baked into weight updates. Baking chain-of-thought prompts improves zero-shot performance on GSM8K, ASDiv, MBPP, ARC-Easy, ARC-Challenge, and CommonsenseQA benchmarks. Baking news headlines directly updates an LLM's knowledge. And baking instructions & personas alleviates "prompt forgetting" over long sequences. Furthermore, stopping baking early creates "half-baked" models, continuously scaling prompt strength. Baked models retain their sensitivity to further prompting and baking, including re-prompting with the baked-in prompt. Surprisingly, the re-prompted models yield further performance gains in instruction following, as well as math reasoning and coding benchmarks. Taking re-prompting and re-baking to the limit yields a form of iterative self-improvement we call Prompt Pursuit, and preliminary results on instruction following exhibit dramatic performance gains. Finally, we discuss implications for AI safety, continuous model updating, enhancing real-time learning capabilities in LLM-based agents, and generating more stable AI personas.