PEARL: Towards Permutation-Resilient LLMs

📄 arXiv: 2502.14628v1 📥 PDF

作者: Liang Chen, Li Shen, Yang Deng, Xiaoyan Zhao, Bin Liang, Kam-Fai Wong

分类: cs.LG, cs.CL

发布日期: 2025-02-20

备注: ICLR 2025


💡 一句话要点

提出PEARL框架,提升大语言模型在上下文学习中对输入排列的鲁棒性

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 上下文学习 大语言模型 鲁棒性 排列不变性 分布鲁棒优化

📋 核心要点

  1. 现有大语言模型的上下文学习能力易受输入示例顺序的影响,导致预测结果不稳定,存在安全隐患。
  2. PEARL框架基于分布鲁棒优化,通过对抗训练提升模型对输入排列的鲁棒性,核心是置换提议网络生成最具挑战性的排列。
  3. 实验表明,PEARL能有效缓解排列攻击,并在多样本和长上下文场景下取得显著性能提升,最高可达40%。

📝 摘要(中文)

大型语言模型(LLMs)的上下文学习(ICL)能力使其能够利用提供的演示来执行具有挑战性的任务。然而,ICL对演示的排序高度敏感,导致预测不稳定。本文表明,这种脆弱性可被利用来设计一种自然的攻击——模型提供商难以检测——通过简单地排列演示,在LLaMA-3上达到近80%的成功率。现有的缓解方法主要依赖于后处理,未能增强模型对输入排列的内在鲁棒性,引发了对LLM安全性和可靠性的担忧。为了解决这个问题,我们提出了一种基于分布鲁棒优化(DRO)的置换鲁棒学习(PEARL)框架,该框架针对最坏情况的输入排列优化模型性能。具体来说,PEARL由一个置换提议网络(P-Net)和LLM组成。P-Net通过将置换视为一个最优传输问题来生成最具挑战性的置换,该问题使用熵约束的Sinkhorn算法解决。通过极小极大优化,P-Net和LLM相互迭代优化,逐步提高LLM的鲁棒性。在合成预训练和真实指令调优任务上的实验表明,PEARL有效地缓解了置换攻击并提高了性能。值得注意的是,尽管PEARL在较少的样本和较短的上下文中进行训练,但在扩展到多样本和长上下文场景时,PEARL实现了高达40%的性能提升,突出了其效率和泛化能力。

🔬 方法详解

问题定义:大语言模型在上下文学习(ICL)中,对输入示例的排列顺序非常敏感。即使输入示例的内容不变,仅仅改变它们的顺序,模型的预测结果也会发生显著变化。这种脆弱性使得攻击者可以通过精心设计的排列方式来误导模型,从而降低模型的可靠性和安全性。现有的缓解方法主要集中在后处理阶段,例如对多个排列结果进行平均,但这些方法无法从根本上提高模型对排列的鲁棒性。

核心思路:PEARL的核心思路是利用分布鲁棒优化(DRO)的思想,训练一个对输入排列具有鲁棒性的语言模型。具体来说,PEARL通过对抗训练的方式,让模型在训练过程中接触到各种可能的输入排列,并学习在最坏情况下的排列下也能保持良好的性能。这样,模型就能更好地适应不同的输入顺序,从而提高其鲁棒性。

技术框架:PEARL框架包含两个主要组成部分:置换提议网络(P-Net)和语言模型(LLM)。P-Net负责生成最具挑战性的输入排列,而LLM则负责在这些排列下进行学习和预测。整个训练过程是一个极小极大优化过程,P-Net试图找到使LLM性能最差的排列,而LLM则试图在这些排列下优化自身的性能。通过这种对抗训练,LLM逐渐变得对输入排列不敏感。

关键创新:PEARL最重要的创新在于它将输入排列的鲁棒性问题转化为一个分布鲁棒优化问题,并通过对抗训练的方式来解决。与现有的后处理方法不同,PEARL直接在模型训练阶段就考虑了输入排列的影响,从而从根本上提高了模型的鲁棒性。此外,P-Net的设计也很有创新性,它将排列生成问题建模为一个最优传输问题,并使用Sinkhorn算法来高效地求解。

关键设计:P-Net将排列生成建模为最优传输问题,使用Sinkhorn算法求解,并加入熵约束以保证解的平滑性。损失函数采用极小极大损失,P-Net的目标是最大化损失,LLM的目标是最小化损失。训练过程中,P-Net和LLM交替优化。此外,论文还探索了不同的P-Net结构和训练策略,以进一步提高PEARL的性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,PEARL能够有效缓解排列攻击,并在合成预训练和真实指令调优任务上都取得了显著的性能提升。特别是在多样本和长上下文场景下,PEARL的性能提升高达40%,这表明PEARL具有良好的泛化能力和效率。此外,PEARL在较少的样本和较短的上下文中进行训练,也能在更复杂的场景中取得优异表现。

🎯 应用场景

PEARL框架可应用于各种需要上下文学习的大语言模型应用场景,例如问答系统、文本摘要、代码生成等。通过提高模型对输入排列的鲁棒性,可以增强这些应用的安全性和可靠性,防止恶意攻击者通过操纵输入顺序来误导模型。此外,PEARL还可以提高模型在实际应用中的泛化能力,使其能够更好地适应不同的用户输入习惯。

📄 摘要(原文)

The in-context learning (ICL) capability of large language models (LLMs) enables them to perform challenging tasks using provided demonstrations. However, ICL is highly sensitive to the ordering of demonstrations, leading to instability in predictions. This paper shows that this vulnerability can be exploited to design a natural attack - difficult for model providers to detect - that achieves nearly 80% success rate on LLaMA-3 by simply permuting the demonstrations. Existing mitigation methods primarily rely on post-processing and fail to enhance the model's inherent robustness to input permutations, raising concerns about safety and reliability of LLMs. To address this issue, we propose Permutation-resilient learning (PEARL), a novel framework based on distributionally robust optimization (DRO), which optimizes model performance against the worst-case input permutation. Specifically, PEARL consists of a permutation-proposal network (P-Net) and the LLM. The P-Net generates the most challenging permutations by treating it as an optimal transport problem, which is solved using an entropy-constrained Sinkhorn algorithm. Through minimax optimization, the P-Net and the LLM iteratively optimize against each other, progressively improving the LLM's robustness. Experiments on synthetic pre-training and real-world instruction tuning tasks demonstrate that PEARL effectively mitigates permutation attacks and enhances performance. Notably, despite being trained on fewer shots and shorter contexts, PEARL achieves performance gains of up to 40% when scaled to many-shot and long-context scenarios, highlighting its efficiency and generalization capabilities.