Multi-Granularity Semantic Revision for Large Language Model Distillation

📄 arXiv: 2407.10068v1 📥 PDF

作者: Xiaoyu Liu, Yun Zhang, Wei Li, Simiao Li, Xudong Huang, Hanting Chen, Yehui Tang, Jie Hu, Zhiwei Xiong, Yunhe Wang

分类: cs.CL

发布日期: 2024-07-14


💡 一句话要点

提出多粒度语义修正方法,提升大语言模型蒸馏效果,减少生成错误。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 知识蒸馏 大语言模型 语义修正 序列生成 模型压缩

📋 核心要点

  1. 现有LLM蒸馏方法过度依赖学生模型输出,易引入生成错误,且损失函数难以对齐关键信息。
  2. 提出多粒度语义修正方法,包括序列校正重生成、分布自适应KL散度和span概率相关性约束。
  3. 实验结果表明,该方法在不同参数规模模型上均优于现有方法,有效提升蒸馏效果。

📝 摘要(中文)

知识蒸馏在压缩大型语言模型(LLMs)中起着关键作用,它可以在大型教师模型的指导下提升小型学生模型的能力。然而,现有的LLM蒸馏方法过度依赖学生模型生成的输出,这可能会引入生成错误并误导蒸馏过程。此外,由于LLM输出的复杂分布,以往的蒸馏损失函数难以对齐信息量最大的部分。为了解决这些问题,我们提出了一种用于LLM蒸馏的多粒度语义修正方法。在序列层面,我们提出了一种序列校正和重生成(SCRG)策略。SCRG首先计算教师和学生之间的语义认知差异以检测错误token,然后用教师生成的token校正它,并重新生成序列以减少生成错误并增强生成多样性。在token层面,我们设计了一种分布自适应裁剪Kullback-Leibler(DAC-KL)损失作为蒸馏目标函数。DAC-KL损失利用一个可学习的子网络自适应地从教师的输出中提取语义密集区域,避免了冗余信息在蒸馏过程中的干扰。最后,在span层面,我们利用序列的span先验来计算span内的概率相关性,并约束教师和学生的概率相关性一致,进一步增强语义信息的传递。在参数范围从0.1B到13B的不同模型系列上进行的大量实验表明,与现有方法相比,我们的方法具有优越性。

🔬 方法详解

问题定义:现有的大语言模型蒸馏方法主要依赖于学生模型自身的生成结果,这导致学生模型在生成过程中产生的错误会反过来影响蒸馏效果,造成误差累积。此外,由于大语言模型输出分布的复杂性,传统的蒸馏损失函数难以有效地对齐教师模型和学生模型之间最重要的语义信息,导致知识迁移效率低下。

核心思路:本文的核心思路是通过多粒度语义修正来提高蒸馏的准确性和效率。具体来说,从序列、token和span三个粒度入手,分别对学生模型的生成结果进行修正,并设计相应的损失函数来更好地对齐教师模型和学生模型的语义信息。通过减少学生模型的生成错误,并聚焦于教师模型输出中的关键语义区域,从而提升知识蒸馏的效果。

技术框架:该方法包含三个主要模块:序列校正和重生成(SCRG)、分布自适应裁剪KL散度(DAC-KL)损失和span概率相关性约束。SCRG模块用于在序列层面修正学生模型的生成错误。DAC-KL损失用于在token层面聚焦教师模型输出中的关键语义区域。span概率相关性约束用于在span层面保持教师模型和学生模型之间的语义一致性。整体流程是先使用SCRG修正学生模型的生成结果,然后使用DAC-KL损失和span概率相关性约束进行蒸馏训练。

关键创新:该方法最重要的创新点在于提出了多粒度的语义修正策略,从序列、token和span三个不同的粒度对学生模型的生成结果进行修正,从而更全面地提升蒸馏效果。与现有方法相比,该方法不仅关注学生模型的输出,还主动地修正学生模型的错误,并聚焦于教师模型输出中的关键语义区域,从而提高了知识迁移的效率和准确性。

关键设计:SCRG模块通过计算教师模型和学生模型之间的语义认知差异来检测错误token,并使用教师模型生成的token进行修正。DAC-KL损失使用一个可学习的子网络来提取教师模型输出中的语义密集区域,并自适应地调整KL散度的计算权重。span概率相关性约束通过计算span内的概率相关性,并约束教师模型和学生模型的概率相关性一致来实现。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在不同参数规模(0.1B到13B)的模型上均优于现有方法。例如,在某个具体任务上,使用该方法蒸馏得到的学生模型相比于使用传统方法蒸馏得到的学生模型,性能提升了X个百分点(具体数值论文中给出)。这些结果证明了该方法在提升大语言模型蒸馏效果方面的有效性。

🎯 应用场景

该研究成果可广泛应用于大语言模型的压缩和加速,尤其是在资源受限的场景下,例如移动设备、嵌入式系统等。通过知识蒸馏,可以将大型语言模型的知识迁移到小型模型中,从而在保证性能的同时,降低计算成本和存储空间需求。此外,该方法还可以用于提升对话系统、机器翻译等自然语言处理任务的性能。

📄 摘要(原文)

Knowledge distillation plays a key role in compressing the Large Language Models (LLMs), which boosts a small-size student model under large teacher models' guidance. However, existing LLM distillation methods overly rely on student-generated outputs, which may introduce generation errors and misguide the distillation process. Moreover, the distillation loss functions introduced in previous art struggle to align the most informative part due to the complex distribution of LLMs' outputs. To address these problems, we propose a multi-granularity semantic revision method for LLM distillation. At the sequence level, we propose a sequence correction and re-generation (SCRG) strategy. SCRG first calculates the semantic cognitive difference between the teacher and student to detect the error token, then corrects it with the teacher-generated one, and re-generates the sequence to reduce generation errors and enhance generation diversity. At the token level, we design a distribution adaptive clipping Kullback-Leibler (DAC-KL) loss as the distillation objective function. DAC-KL loss exploits a learnable sub-network to adaptively extract semantically dense areas from the teacher's output, avoiding the interference of redundant information in the distillation process. Finally, at the span level, we leverage the span priors of a sequence to compute the probability correlations within spans, and constrain the teacher and student's probability correlations to be consistent, further enhancing the transfer of semantic information. Extensive experiments across different model families with parameters ranging from 0.1B to 13B demonstrate the superiority of our method compared to existing methods.