Preserving Diversity in Supervised Fine-Tuning of Large Language Models
作者: Ziniu Li, Congliang Chen, Tian Xu, Zeyu Qin, Jiancong Xiao, Zhi-Quan Luo, Ruoyu Sun
分类: cs.LG, cs.AI
发布日期: 2024-08-29 (更新: 2025-04-05)
备注: accepted by ICLR 2025
💡 一句话要点
提出GEM算法,通过博弈论方法提升大语言模型有监督微调中的多样性
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大语言模型 有监督微调 多样性 博弈论 反向KL散度 熵正则化 模型遗忘
📋 核心要点
- 现有有监督微调方法依赖交叉熵损失,导致大语言模型输出多样性降低,限制了模型探索更优解的能力。
- 论文提出基于博弈论的SFT框架,引入辅助变量调节学习过程,等价于带熵正则化的反向KL散度最小化,从而提升输出多样性。
- 实验结果表明,GEM算法在保持下游任务性能的同时,显著提升了输出多样性,并能有效缓解模型遗忘问题。
📝 摘要(中文)
大型语言模型(LLMs)通常依赖于有监督微调(SFT)来专门化于下游任务,交叉熵(CE)损失是事实上的选择。然而,CE最大化了观察数据的可能性,而没有考虑到其他的可能性。因此,CE通常会导致模型输出的多样性降低,这阻碍了需要抽样来探索更好响应的进一步发展。为了解决这个限制,本文为SFT引入了一种新的博弈论公式。在这个框架中,引入了一个辅助变量来调节学习过程。我们证明了所提出的博弈论方法与具有熵正则化的反向KL最小化问题相关联。这种正则化防止了对训练数据的过度记忆,并促进了输出多样性。为了实现这个框架,我们开发了GEM,一种新的训练算法,它通过利用LLM的一些独特属性,在计算上与CE一样高效。对3B到70B参数的预训练模型的实证研究表明,GEM在下游性能上与CE相当,同时显著提高了输出多样性。这种增加的多样性转化为聊天和代码生成任务中测试时计算扩展的性能提升。此外,我们观察到,保持输出多样性具有减轻遗忘的额外好处,因为保持多样化的输出鼓励模型在整个训练过程中保留预训练的知识。
🔬 方法详解
问题定义:现有的大语言模型有监督微调(SFT)方法,主要采用交叉熵(CE)损失函数。这种方法倾向于最大化训练数据的似然,而忽略了其他可能的输出,导致模型过度拟合训练数据,输出多样性降低。这限制了模型在需要探索性生成任务中的表现,例如对话生成和代码生成等。
核心思路:论文的核心思路是通过引入博弈论的视角,将SFT过程建模为一个博弈过程。在这个博弈过程中,模型不仅要最大化训练数据的似然,还要考虑到输出的多样性。通过引入一个辅助变量来调节学习过程,鼓励模型生成更多样化的输出。
技术框架:GEM算法的整体框架如下:首先,将SFT过程建模为一个博弈问题,引入一个辅助变量来调节学习过程。然后,证明该博弈问题等价于带熵正则化的反向KL散度最小化问题。最后,设计了一种计算高效的训练算法GEM,该算法利用LLM的特性,使得计算复杂度与CE损失相当。
关键创新:该论文的关键创新在于将博弈论引入到SFT过程中,并证明其等价于带熵正则化的反向KL散度最小化。这种方法能够有效地提升模型输出的多样性,从而改善模型在探索性生成任务中的表现。此外,GEM算法在计算效率上与CE损失相当,使其能够应用于大规模语言模型的微调。
关键设计:GEM算法的关键设计包括:1) 博弈论框架的构建,引入辅助变量调节学习过程;2) 证明博弈问题与带熵正则化的反向KL散度最小化的等价性;3) 设计计算高效的训练算法,利用LLM的特性降低计算复杂度。具体的损失函数设计未知,论文中可能包含更详细的参数设置。
🖼️ 关键图片
📊 实验亮点
实验结果表明,GEM算法在3B到70B参数的预训练模型上,能够在保持下游任务性能与CE损失相当的同时,显著提升输出多样性。在聊天和代码生成任务中,GEM算法能够通过增加测试时计算扩展来获得性能提升。此外,GEM算法还能够有效缓解模型遗忘问题。
🎯 应用场景
该研究成果可广泛应用于各种需要大语言模型进行微调的场景,尤其是在对话生成、代码生成等需要模型具备较高创造性和探索能力的领域。通过提升模型输出的多样性,可以改善用户体验,提高生成内容的质量,并促进大语言模型在更广泛领域的应用。
📄 摘要(原文)
Large Language Models (LLMs) typically rely on Supervised Fine-Tuning (SFT) to specialize in downstream tasks, with the Cross Entropy (CE) loss being the de facto choice. However, CE maximizes the likelihood of observed data without accounting for alternative possibilities. As such, CE usually leads to reduced diversity in the model's outputs, which hinders further development that requires sampling to explore better responses. To address this limitation, this paper introduces a new game-theoretic formulation for SFT. In this framework, an auxiliary variable is introduced to regulate the learning process. We prove that the proposed game-theoretic approach connects to the problem of reverse KL minimization with entropy regularization. This regularization prevents over-memorization of training data and promotes output diversity. To implement this framework, we develop GEM, a new training algorithm that is computationally efficient as CE by leveraging some unique properties of LLMs. Empirical studies of pre-trained models from 3B to 70B parameters show that GEM achieves comparable downstream performance to CE while significantly enhancing output diversity. This increased diversity translates to performance gains in test-time compute scaling for chat and code generation tasks. Moreover, we observe that preserving output diversity has the added benefit of mitigating forgetting, as maintaining diverse outputs encourages models to retain pre-trained knowledge throughout the training process.