Inference-Aware Fine-Tuning for Best-of-N Sampling in Large Language Models

📄 arXiv: 2412.15287v2 📥 PDF

作者: Yinlam Chow, Guy Tennenholtz, Izzeddin Gur, Vincent Zhuang, Bo Dai, Sridhar Thiagarajan, Craig Boutilier, Rishabh Agarwal, Aviral Kumar, Aleksandra Faust

分类: cs.CL, cs.AI, cs.LG

发布日期: 2024-12-18 (更新: 2025-11-25)


💡 一句话要点

提出推理感知微调方法,优化大语言模型Best-of-N采样策略性能

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 推理感知微调 Best-of-N采样 模仿学习 强化学习 元策略 代码生成 数学问题

📋 核心要点

  1. 现有大语言模型推理效率不高,未能充分利用推理时计算资源以提升性能。
  2. 提出推理感知微调,针对Best-of-N策略优化模型,学习探索-利用的元策略。
  3. 实验表明,该方法显著提升了模型在数学问题和代码生成任务上的性能。

📝 摘要(中文)

本文提出了一种新颖的推理感知微调范式,该范式以直接优化推理时策略的性能的方式对模型进行微调。我们使用简单而有效的Best-of-N (BoN) 推理策略来研究这种范式,其中验证器从一组LLM生成的响应中选择最佳响应。我们设计了第一个用于BoN感知微调的模仿学习和强化学习(RL)方法,克服了BoN中具有挑战性的、不可微的argmax算子。实验表明,我们的BoN感知模型隐式地学习了一种元策略,该策略将最佳响应与可能更适合测试时输入的多样化响应交错,这一过程让人联想到RL中的探索-利用权衡。实验结果表明,BoN感知微调在提高性能和推理时计算效率方面是有效的。特别是,我们的方法将Gemma 2B在Hendrycks MATH上的Bo32性能从26.8%提高到30.8%,pass@32从60.0%提高到67.0%,以及HumanEval上的pass@16从61.6%提高到67.1%。

🔬 方法详解

问题定义:论文旨在解决如何更有效地利用推理时计算资源,提升大语言模型在Best-of-N采样策略下的性能。现有方法通常独立优化模型训练和推理策略,忽略了两者之间的相互影响,导致推理时性能未达到最优。现有方法难以处理Best-of-N策略中argmax操作的不可微性,阻碍了端到端优化。

核心思路:论文的核心思路是进行推理感知微调,即在模型微调阶段,直接优化推理时所采用的策略(Best-of-N)。通过模仿学习和强化学习,使模型学习到一种元策略,能够根据输入自适应地平衡最佳响应和多样化响应,从而更好地适应测试时输入。

技术框架:整体框架包括预训练的大语言模型、Best-of-N采样策略以及微调模块。微调模块采用模仿学习或强化学习方法,以优化Best-of-N策略的性能。具体流程为:首先,使用大语言模型生成N个候选响应;然后,使用验证器(可以是另一个模型或人工标注)对这些响应进行评估;最后,根据评估结果,使用模仿学习或强化学习方法对模型进行微调,使其能够更好地生成高质量的候选响应。

关键创新:最重要的技术创新点在于提出了推理感知微调范式,将模型训练和推理策略优化相结合。通过模仿学习和强化学习,克服了Best-of-N策略中argmax操作的不可微性,实现了端到端的优化。模型学习到的元策略能够自适应地平衡最佳响应和多样化响应,提高了模型的泛化能力。

关键设计:在模仿学习中,使用验证器选择的最佳响应作为目标,训练模型生成高质量的候选响应。在强化学习中,使用验证器的评估结果作为奖励信号,训练模型生成能够获得更高奖励的候选响应。具体损失函数的设计需要考虑argmax操作的不可微性,例如可以使用Gumbel-Softmax技巧进行近似。参数设置方面,需要根据具体的任务和数据集进行调整,例如Best-of-N中的N值,以及模仿学习和强化学习中的学习率、奖励函数等。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法显著提升了Gemma 2B模型在Hendrycks MATH和HumanEval数据集上的性能。在Hendrycks MATH数据集上,Bo32性能从26.8%提高到30.8%,pass@32从60.0%提高到67.0%。在HumanEval数据集上,pass@16从61.6%提高到67.1%。这些结果表明,推理感知微调能够有效地提升大语言模型的性能。

🎯 应用场景

该研究成果可应用于各种需要高质量文本生成的场景,例如机器翻译、文本摘要、对话系统、代码生成等。通过推理感知微调,可以显著提升大语言模型在这些任务上的性能和效率,降低计算成本,并提高用户体验。未来,该方法有望推广到其他推理策略和模型架构,进一步推动大语言模型的发展。

📄 摘要(原文)

Recent studies have indicated that effectively utilizing inference-time compute is crucial for attaining better performance from large language models (LLMs). In this work, we propose a novel inference-aware fine-tuning paradigm, in which the model is fine-tuned in a manner that directly optimizes the performance of the inference-time strategy. We study this paradigm using the simple yet effective Best-of-N (BoN) inference strategy, in which a verifier selects the best out of a set of LLM-generated responses. We devise the first imitation learning and reinforcement learning~(RL) methods for BoN-aware fine-tuning, overcoming the challenging, non-differentiable argmax operator within BoN. We empirically demonstrate that our BoN-aware models implicitly learn a meta-strategy that interleaves best responses with more diverse responses that might be better suited to a test-time input -- a process reminiscent of the exploration-exploitation trade-off in RL. Our experiments demonstrate the effectiveness of BoN-aware fine-tuning in terms of improved performance and inference-time compute. In particular, we show that our methods improve the Bo32 performance of Gemma 2B on Hendrycks MATH from 26.8% to 30.8%, and pass@32 from 60.0% to 67.0%, as well as the pass@16 on HumanEval from 61.6% to 67.1%.