BiLD: Bi-directional Logits Difference Loss for Large Language Model Distillation

📄 arXiv: 2406.13555v3 📥 PDF

作者: Minchong Li, Feng Zhou, Xiaohui Song

分类: cs.CL, cs.AI

发布日期: 2024-06-19 (更新: 2025-02-18)

备注: COLING 2025


💡 一句话要点

提出BiLD损失,通过双向Logits差异蒸馏提升大语言模型性能。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 知识蒸馏 大型语言模型 Logits蒸馏 长尾分布 模型压缩

📋 核心要点

  1. 现有logits蒸馏方法难以有效利用LLM logits的内部排序信息,且LLM logits存在更严重的长尾分布噪声。
  2. 提出BiLD损失,通过仅使用top-k logits来过滤长尾噪声,并构建logits差异来利用内部排序信息。
  3. 在13个数据集上,BiLD损失仅使用top-8 logits就超越了SFT、KL散度和多种NLP及CV领域的蒸馏方法。

📝 摘要(中文)

近年来,大型语言模型(LLMs)在各种自然语言处理(NLP)任务中展现出卓越的能力。然而,这种令人印象深刻的性能通常以参数规模的增加为代价,给广泛部署带来了重大挑战。知识蒸馏(KD)提供了一种解决方案,通过将知识从大型教师模型转移到较小的学生模型。在本文中,我们探索了LLMs在logit层面的特定任务蒸馏。我们的研究表明,微调后的LLMs的logits表现出比视觉模型更极端的长尾分布,长尾中隐藏的“噪声”会影响蒸馏性能。此外,现有的logits蒸馏方法通常难以有效地利用logits的内部排序信息。为了解决这些问题,我们提出了双向Logits差异(BiLD)损失。BiLD损失仅利用top-$k$的教师和学生logits来过滤掉长尾噪声,并通过构建logits差异来利用内部logits排序信息。为了评估BiLD损失,我们使用两种类型的LLMs在13个数据集上进行了全面的实验。我们的结果表明,仅使用top-8 logits的BiLD损失优于监督微调(SFT)、原始KL损失以及来自NLP和CV领域的其他五种蒸馏方法。

🔬 方法详解

问题定义:现有logits蒸馏方法在应用于大型语言模型时,存在两个主要问题。一是大型语言模型微调后的logits分布呈现出比视觉模型更极端的长尾分布,长尾部分包含大量噪声,影响蒸馏效果。二是现有方法难以有效利用logits内部的排序信息,而这些排序信息对于知识传递至关重要。

核心思路:BiLD损失的核心思路是双向利用logits的差异信息,同时抑制长尾噪声。具体来说,它只关注top-k个logits,认为这些logits包含了最重要的信息,而长尾部分则被视为噪声。通过计算教师和学生模型top-k logits之间的差异,以及利用logits的排序关系构建差异,从而实现更有效的知识传递。

技术框架:BiLD损失的整体框架包括以下步骤:1) 获取教师模型和学生模型的logits;2) 分别选取教师和学生模型的top-k logits;3) 计算教师和学生模型top-k logits之间的差异;4) 利用logits的排序关系构建差异;5) 将上述差异进行加权求和,得到最终的BiLD损失。

关键创新:BiLD损失的关键创新在于其双向差异计算和长尾噪声抑制。传统的logits蒸馏方法通常只关注logits的绝对值,而忽略了logits之间的相对关系。BiLD损失通过计算logits的差异,显式地利用了logits的排序信息。同时,通过只关注top-k logits,有效地抑制了长尾噪声的影响。

关键设计:BiLD损失的关键设计包括:1) top-k的选择:论文实验中,k=8时效果最佳。2) 差异计算方式:论文采用了绝对值差异和排序差异两种方式。3) 损失函数权重:论文通过实验调整了不同差异项的权重,以达到最佳性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,BiLD损失在13个数据集上均取得了优异的性能。仅使用top-8 logits的BiLD损失,就超越了监督微调(SFT)、原始KL损失以及来自NLP和CV领域的其他五种蒸馏方法。这证明了BiLD损失在大型语言模型蒸馏方面的有效性和优越性。

🎯 应用场景

BiLD损失可应用于各种需要知识蒸馏的大型语言模型场景,例如模型压缩、加速推理、降低部署成本等。该方法尤其适用于资源受限的设备,例如移动设备和嵌入式系统,可以在保证模型性能的同时,显著减小模型体积和计算复杂度。未来,该方法可以进一步扩展到其他类型的模型和任务中。

📄 摘要(原文)

In recent years, large language models (LLMs) have shown exceptional capabilities across various natural language processing (NLP) tasks. However, such impressive performance often comes with the trade-off of an increased parameter size, posing significant challenges for widespread deployment. Knowledge distillation (KD) provides a solution by transferring knowledge from a large teacher model to a smaller student model. In this paper, we explore the task-specific distillation of LLMs at the logit level. Our investigation reveals that the logits of fine-tuned LLMs exhibit a more extreme long-tail distribution than those from vision models, with hidden "noise" in the long tail affecting distillation performance. Furthermore, existing logits distillation methods often struggle to effectively utilize the internal ranking information from the logits. To address these, we propose the Bi-directional Logits Difference (BiLD) loss. The BiLD loss filters out the long-tail noise by utilizing only top-$k$ teacher and student logits, and leverages the internal logits ranking information by constructing logits differences. To evaluate BiLD loss, we conduct comprehensive experiments on 13 datasets using two types of LLMs. Our results show that the BiLD loss, with only the top-8 logits, outperforms supervised fine-tuning (SFT), vanilla KL loss, and five other distillation methods from both NLP and CV fields.