WPN: An Unlearning Method Based on N-pair Contrastive Learning in Language Models
作者: Guitao Chen, Yunshen Wang, Hongye Sun, Guang Chen
分类: cs.CL, cs.IR
发布日期: 2024-08-18
备注: ECAI 2024
💡 一句话要点
提出WPN:一种基于N-pair对比学习的语言模型不可学习方法,用于消除有害知识。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 语言模型 不可学习 对比学习 有害内容过滤 模型安全
📋 核心要点
- 现有基于梯度上升的不可学习方法会显著降低语言模型的性能,这是核心问题。
- WPN方法利用位置加权平均池化和N-pair对比学习,旨在修改模型输出分布,消除有害输出。
- 实验表明,WPN能有效降低有害响应比例,同时保持模型在常见基准上的性能稳定。
📝 摘要(中文)
生成式语言模型(LMs)具有诸多优势,但由于在预训练期间获得的有害知识,可能产生不适当或有害的输出。这种知识通常表现为不良的对应关系,例如“有害提示”导致“有害输出”,我们的研究旨在通过不可学习技术来缓解这种情况。然而,现有的基于梯度上升的不可学习方法会显著损害LMs的性能。为了解决这个问题,我们提出了一种名为加权位置N-pair(WPN)学习的新方法,该方法利用n-pair对比学习框架内的位置加权平均池化。WPN旨在通过消除特定的有害输出(例如,用中性响应替换有毒响应)来修改LMs的输出分布,从而将模型的行为从“有害提示-有害输出”转变为“有害提示-无害响应”。在OPT和GPT-NEO LMs上的实验表明,WPN有效地降低了有害响应的比例,实现了高达95.8%的无害率,同时在九个常见基准上保持了稳定的性能(平均降幅小于2%)。此外,我们提供了经验证据来证明WPN在泛化性和鲁棒性方面削弱有害对应关系的能力,这些能力在分布外测试集和对抗攻击下进行了评估。
🔬 方法详解
问题定义:论文旨在解决语言模型由于预训练数据中的有害知识而产生不当或有害输出的问题。现有基于梯度上升的不可学习方法在消除有害知识的同时,会显著损害语言模型的整体性能,导致模型能力下降。
核心思路:论文的核心思路是通过修改语言模型的输出分布,将有害提示引导至无害响应,从而消除“有害提示-有害输出”的对应关系。WPN方法通过对比学习,鼓励模型对有害提示产生更安全、更中性的输出,同时避免过度调整模型参数,从而保持模型性能。
技术框架:WPN方法基于N-pair对比学习框架。首先,对语言模型的输出进行位置加权平均池化,得到每个提示的嵌入表示。然后,构建N-pair对比损失函数,该损失函数鼓励模型将有害提示的嵌入表示与无害响应的嵌入表示拉近,同时推远有害响应的嵌入表示。通过最小化该损失函数,模型学习到一种新的输出分布,从而减少有害输出的产生。
关键创新:WPN的关键创新在于结合了位置加权平均池化和N-pair对比学习。位置加权平均池化能够更好地捕捉语言模型输出中的关键信息,而N-pair对比学习能够更有效地学习到有害提示和无害响应之间的关系。此外,WPN方法还通过调整损失函数的权重,平衡了不可学习和模型性能之间的trade-off。
关键设计:WPN方法使用位置加权平均池化来提取语言模型输出的嵌入表示。具体来说,每个位置的权重是根据其在序列中的位置计算的。N-pair对比损失函数的设计如下:对于每个有害提示,选择一个无害响应作为正样本,并选择其他有害响应作为负样本。损失函数的目标是最小化正样本之间的距离,同时最大化负样本之间的距离。损失函数的权重可以根据有害提示的严重程度进行调整。
🖼️ 关键图片
📊 实验亮点
实验结果表明,WPN方法在OPT和GPT-NEO语言模型上实现了高达95.8%的无害率,同时在九个常见基准测试中,模型性能平均下降小于2%。此外,WPN方法在分布外测试集和对抗攻击下表现出良好的泛化性和鲁棒性,证明了其有效性。
🎯 应用场景
WPN方法可应用于各种生成式语言模型,以提高其安全性和可靠性。例如,可以用于过滤社交媒体平台上的有害内容,防止聊天机器人生成不当回复,以及提高代码生成模型的安全性。该研究对于构建更负责任和可信赖的人工智能系统具有重要意义。
📄 摘要(原文)
Generative language models (LMs) offer numerous advantages but may produce inappropriate or harmful outputs due to the harmful knowledge acquired during pre-training. This knowledge often manifests as undesirable correspondences, such as "harmful prompts" leading to "harmful outputs," which our research aims to mitigate through unlearning techniques.However, existing unlearning methods based on gradient ascent can significantly impair the performance of LMs. To address this issue, we propose a novel approach called Weighted Positional N-pair (WPN) Learning, which leverages position-weighted mean pooling within an n-pair contrastive learning framework. WPN is designed to modify the output distribution of LMs by eliminating specific harmful outputs (e.g., replacing toxic responses with neutral ones), thereby transforming the model's behavior from "harmful prompt-harmful output" to "harmful prompt-harmless response".Experiments on OPT and GPT-NEO LMs show that WPN effectively reduces the proportion of harmful responses, achieving a harmless rate of up to 95.8\% while maintaining stable performance on nine common benchmarks (with less than 2\% degradation on average). Moreover, we provide empirical evidence to demonstrate WPN's ability to weaken the harmful correspondences in terms of generalizability and robustness, as evaluated on out-of-distribution test sets and under adversarial attacks.