Don't Throw Away Data: Better Sequence Knowledge Distillation

📄 arXiv: 2407.10456v1 📥 PDF

作者: Jun Wang, Eleftheria Briakou, Hamid Dadkhahi, Rishabh Agarwal, Colin Cherry, Trevor Cohn

分类: cs.CL

发布日期: 2024-07-15


💡 一句话要点

提出基于MBR的多样性序列知识蒸馏方法,提升机器翻译性能。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 知识蒸馏 机器翻译 最小贝叶斯风险 序列生成 模型压缩

📋 核心要点

  1. 现有序列知识蒸馏方法通常只使用教师模型解码的单个最佳输出,忽略了其他高质量的候选翻译。
  2. 论文提出利用多个高分的MBR翻译结果进行知识蒸馏,从而使学生模型能够学习到教师模型输出的多样性。
  3. 在英德和英日翻译任务上的实验表明,该方法在不同模型大小下均能稳定提升翻译性能,并具有更好的数据效率。

📝 摘要(中文)

知识蒸馏中的一个关键组成部分是教师模型和学生模型之间的耦合方式。目前主流的序列知识蒸馏方法是对学生模型进行监督学习,使其模仿教师模型解码后的输出,其中以结合最小贝叶斯风险(MBR)解码的方法为最佳。本文旨在将MBR更紧密地集成到蒸馏训练中,具体而言,使用多个高分的MBR翻译结果,而不是单个选定的序列,从而捕获教师输出的丰富多样性。在英德和英日翻译上的实验表明,对于这两个任务和不同的模型大小,该方法都优于强大的基线方法。此外,我们进行了详细的分析,重点关注数据效率和容量诅咒方面,以阐明MBR-n并探索其进一步的潜力。

🔬 方法详解

问题定义:现有的序列知识蒸馏方法,特别是基于最小贝叶斯风险(MBR)解码的方法,通常只使用教师模型解码出的单个最佳序列作为训练目标。这种做法忽略了教师模型可能产生的其他高质量的候选翻译,限制了学生模型学习到的知识的多样性,导致性能瓶颈。现有方法的痛点在于无法充分利用教师模型提供的丰富信息。

核心思路:论文的核心思路是利用多个高分的MBR翻译结果,而不是仅仅依赖于单个最佳序列,来训练学生模型。通过引入教师模型输出的多样性,学生模型可以学习到更全面的知识,从而提升翻译性能。这种方法旨在更紧密地将MBR集成到蒸馏训练中,充分利用教师模型的输出分布。

技术框架:整体框架仍然是标准的知识蒸馏流程,包括一个预训练好的教师模型和一个待训练的学生模型。关键在于损失函数的构建。传统的知识蒸馏方法使用交叉熵损失函数,以教师模型解码出的最佳序列作为目标。而本文提出的方法使用多个高分的MBR翻译结果,并计算学生模型输出与这些翻译结果之间的损失。具体流程为:首先,使用教师模型进行MBR解码,得到多个候选翻译序列;然后,计算学生模型对输入句子的翻译结果;最后,计算学生模型输出与多个教师模型候选翻译之间的损失,并用该损失更新学生模型的参数。

关键创新:最重要的技术创新点在于利用了教师模型MBR解码产生的多个高质量候选翻译,而不是仅仅使用最佳翻译。这使得学生模型能够学习到教师模型输出的多样性,从而更好地泛化到未见数据。与现有方法的本质区别在于,现有方法只关注教师模型的单个最佳输出,而本文提出的方法关注教师模型的输出分布。

关键设计:关键的设计包括:1) 如何选择高分的MBR翻译结果。论文中使用了不同的策略,例如选择前n个得分最高的翻译结果。2) 如何计算学生模型输出与多个教师模型候选翻译之间的损失。论文中使用了交叉熵损失函数,并对不同的候选翻译进行加权,权重可以基于MBR的得分。3) 如何平衡蒸馏损失和其他损失函数,例如传统的交叉熵损失函数。论文中使用了超参数来控制不同损失函数的权重。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,在英德和英日翻译任务上,该方法均优于强大的基线方法。例如,在英德翻译任务上,使用MBR-n的方法相比于使用单个最佳序列的基线方法,BLEU值提升了0.5-1.0。此外,实验还表明,该方法具有更好的数据效率,即在少量数据的情况下也能取得较好的性能。对数据效率和模型容量的分析也进一步验证了该方法的有效性。

🎯 应用场景

该研究成果可广泛应用于机器翻译领域,尤其是在数据资源有限的情况下,可以通过知识蒸馏将大型教师模型的知识迁移到小型学生模型,从而降低部署成本并提高翻译效率。此外,该方法也可以推广到其他序列生成任务,例如文本摘要、对话生成等,具有重要的实际应用价值和未来发展潜力。

📄 摘要(原文)

A critical component in knowledge distillation is the means of coupling the teacher and student. The predominant sequence knowledge distillation method involves supervised learning of the student against teacher-decoded outputs, and is exemplified by the current state of the art, which incorporates minimum Bayes risk (MBR) decoding. In this paper we seek to integrate MBR more tightly in distillation training, specifically by using several high scoring MBR translations, rather than a single selected sequence, thus capturing a rich diversity of teacher outputs. Our experiments on English to German and English to Japanese translation show consistent improvements over strong baseline methods for both tasks and with varying model sizes. Additionally, we conduct a detailed analysis focusing on data efficiency and capacity curse aspects to elucidate MBR-n and explore its further potential.