BOND: Aligning LLMs with Best-of-N Distillation

📄 arXiv: 2407.14622v1 📥 PDF

作者: Pier Giuseppe Sessa, Robert Dadashi, Léonard Hussenot, Johan Ferret, Nino Vieillard, Alexandre Ramé, Bobak Shariari, Sarah Perrin, Abe Friesen, Geoffrey Cideron, Sertan Girgin, Piotr Stanczyk, Andrea Michi, Danila Sinopalnikov, Sabela Ramos, Amélie Héliou, Aliaksei Severyn, Matt Hoffman, Nikola Momchev, Olivier Bachem

分类: cs.LG, cs.AI, cs.CL

发布日期: 2024-07-19


💡 一句话要点

提出BOND算法,通过模仿Best-of-N采样提升大语言模型性能,降低推理计算开销。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 强化学习 人类反馈 蒸馏学习 分布匹配 Best-of-N采样 抽象摘要

📋 核心要点

  1. 现有RLHF方法驱动了大语言模型质量和安全性,但推理时Best-of-N采样策略计算开销巨大。
  2. BOND算法通过分布匹配,使模型生成分布逼近Best-of-N,从而在推理时无需多次采样。
  3. 实验表明,BOND算法在抽象摘要任务和Gemma模型上,性能优于其他RLHF算法。

📝 摘要(中文)

本文提出了一种名为Best-of-N Distillation (BOND) 的新型RLHF算法。该算法旨在模仿Best-of-N采样,但避免其在推理时产生的大量计算开销。具体来说,BOND是一种分布匹配算法,它强制策略生成的分布更接近Best-of-N分布。我们使用Jeffreys散度(前向和后向KL散度的线性组合)来平衡模式覆盖和模式寻求行为,并推导出一种利用移动锚点的迭代公式以提高效率。通过在抽象摘要和Gemma模型上的实验,我们证明了该方法的有效性和几个设计选择。将Gemma策略与BOND对齐,在多个基准测试中优于其他RLHF算法。

🔬 方法详解

问题定义:论文旨在解决RLHF训练的大语言模型在推理时计算开销过大的问题。Best-of-N采样虽然能提升生成质量,但需要多次采样并选择最佳结果,导致推理成本显著增加。现有RLHF方法难以在推理效率和生成质量之间取得平衡。

核心思路:BOND的核心思路是通过蒸馏学习,让模型学习Best-of-N采样的输出分布,从而在推理时只需一次采样就能获得接近Best-of-N的效果。通过分布匹配,使得模型的生成分布尽可能接近Best-of-N的输出分布,从而避免了推理时多次采样的需求。

技术框架:BOND算法是一个迭代的训练过程,主要包含以下几个阶段:1) 使用当前策略生成N个候选样本;2) 使用奖励模型对这些样本进行评分,选择最佳样本;3) 使用Jeffreys散度作为损失函数,优化策略模型,使其生成分布更接近最佳样本的分布。Jeffreys散度用于平衡模式覆盖和模式寻求行为。同时,使用移动锚点来提高训练效率。

关键创新:BOND算法的关键创新在于使用分布匹配的方式来模仿Best-of-N采样,从而在推理时避免了多次采样带来的计算开销。此外,使用Jeffreys散度平衡模式覆盖和模式寻求,以及使用移动锚点提高训练效率也是重要的创新点。与现有方法的本质区别在于,BOND不是直接优化奖励函数,而是学习目标分布。

关键设计:BOND算法的关键设计包括:1) 使用Jeffreys散度作为损失函数,其是前向KL散度和后向KL散度的线性组合,权重系数需要仔细调整以平衡模式覆盖和模式寻求;2) 使用移动锚点,即在每次迭代中使用上一次迭代的策略模型作为锚点,可以加速训练过程;3) Best-of-N的N值,需要根据具体任务和模型大小进行调整,以获得最佳性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,在抽象摘要任务和Gemma模型上,BOND算法优于其他RLHF算法。具体来说,BOND在多个基准测试中取得了更高的ROUGE分数,表明其生成的摘要质量更高。此外,BOND算法在推理时只需要一次采样,显著降低了计算开销,使得其在实际应用中更具优势。

🎯 应用场景

BOND算法可广泛应用于需要高质量文本生成的大语言模型应用场景,例如:智能客服、文本摘要、机器翻译、内容创作等。该算法降低了推理计算开销,使得在资源受限的环境中部署高性能大语言模型成为可能。未来,BOND算法可以扩展到其他模态,例如图像和音频生成,提升生成质量和效率。

📄 摘要(原文)

Reinforcement learning from human feedback (RLHF) is a key driver of quality and safety in state-of-the-art large language models. Yet, a surprisingly simple and strong inference-time strategy is Best-of-N sampling that selects the best generation among N candidates. In this paper, we propose Best-of-N Distillation (BOND), a novel RLHF algorithm that seeks to emulate Best-of-N but without its significant computational overhead at inference time. Specifically, BOND is a distribution matching algorithm that forces the distribution of generations from the policy to get closer to the Best-of-N distribution. We use the Jeffreys divergence (a linear combination of forward and backward KL) to balance between mode-covering and mode-seeking behavior, and derive an iterative formulation that utilizes a moving anchor for efficiency. We demonstrate the effectiveness of our approach and several design choices through experiments on abstractive summarization and Gemma models. Aligning Gemma policies with BOND outperforms other RLHF algorithms by improving results on several benchmarks.