Investigating Mysteries of CoT-Augmented Distillation

📄 arXiv: 2406.14511v2 📥 PDF

作者: Somin Wadhwa, Silvio Amir, Byron C. Wallace

分类: cs.CL

发布日期: 2024-06-20 (更新: 2024-09-27)

备注: Accepted to EMNLP 2024


💡 一句话要点

探究CoT增强蒸馏的内在机理,发现关键token足以提升学生模型性能

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 思维链 模型蒸馏 语言模型 消融实验 知识迁移

📋 核心要点

  1. 现有模型蒸馏方法缺乏对CoT推理过程的深入理解,限制了学生模型性能的进一步提升。
  2. 该论文通过消融实验,分析CoT序列在模型蒸馏中的作用,揭示其内在机理。
  3. 实验表明,CoT序列的位置和连贯性并非关键,少量关键token即可实现性能提升。

📝 摘要(中文)

本文旨在探究“思维链”(CoT)推理过程如何助力模型蒸馏。CoT通过生成中间推理步骤来提升大型语言模型在问答等任务上的表现。最近的研究表明,将从大型“教师”模型中提取的CoT序列与目标标签一起用于微调小型“学生”模型,可以显著提高学生模型的性能。本文通过消融实验来研究这种额外训练信号的有效性。研究结果表明:(1)将CoT序列置于标签之后比置于标签之前能获得更好的下游性能,这意味着学生模型在测试时无需进行推理即可获得性能提升。(2)当以这种方式附加CoT序列时,它们不必是连贯的推理序列也能产生改进;性能提升对CoT token的排列具有鲁棒性。实际上,(3)少量关键token足以实现与使用完整CoT序列进行模型蒸馏时观察到的性能提升。

🔬 方法详解

问题定义:论文旨在解决模型蒸馏中,如何有效利用大型语言模型(LLM)的思维链(CoT)推理过程,来提升小型学生模型的性能。现有方法通常直接将CoT序列作为额外的训练信号,但缺乏对CoT序列内在作用机制的深入理解,导致学生模型性能提升有限,且计算成本较高。

核心思路:论文的核心思路是通过消融实验,分析CoT序列在模型蒸馏中的作用,从而揭示其内在机理。具体而言,论文研究了CoT序列的位置、连贯性以及token数量对学生模型性能的影响,旨在找到最有效的CoT利用方式。

技术框架:论文的技术框架主要包括以下几个步骤:1) 使用大型教师模型生成CoT序列;2) 构建包含CoT序列和目标标签的训练数据集;3) 使用该数据集微调小型学生模型;4) 通过消融实验,分析CoT序列的不同属性对学生模型性能的影响。消融实验主要包括改变CoT序列的位置(标签前/后)、打乱CoT序列的token顺序、以及减少CoT序列的token数量。

关键创新:论文的关键创新在于发现CoT序列的位置和连贯性并非提升学生模型性能的关键因素,而少量关键token足以实现与使用完整CoT序列相当的性能提升。这一发现颠覆了以往对CoT作用的认知,为模型蒸馏提供了新的思路。

关键设计:论文的关键设计在于消融实验的设计。通过系统地改变CoT序列的属性,并观察学生模型性能的变化,论文能够有效地分析CoT序列的内在作用机制。例如,通过将CoT序列置于标签之后,论文发现学生模型在测试时无需进行推理即可获得性能提升,这表明CoT序列可能主要起到一种“提示”或“正则化”的作用。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,将CoT序列置于标签之后比置于标签之前能获得更好的下游性能。更令人惊讶的是,即使打乱CoT序列的token顺序,学生模型的性能依然能够得到提升。此外,仅使用少量关键token进行CoT增强蒸馏,即可实现与使用完整CoT序列相当的性能提升。这些发现挑战了以往对CoT作用的认知。

🎯 应用场景

该研究成果可应用于各种需要模型蒸馏的场景,例如将大型语言模型部署到资源受限的设备上。通过利用少量关键token进行CoT增强蒸馏,可以显著降低计算成本,提高模型部署效率。此外,该研究还有助于更好地理解思维链推理过程,为开发更有效的模型蒸馏方法提供理论指导。

📄 摘要(原文)

Eliciting "chain of thought" (CoT) rationales -- sequences of token that convey a "reasoning" process -- has been shown to consistently improve LLM performance on tasks like question answering. More recent efforts have shown that such rationales can also be used for model distillation: Including CoT sequences (elicited from a large "teacher" model) in addition to target labels when fine-tuning a small student model yields (often substantial) improvements. In this work we ask: Why and how does this additional training signal help in model distillation? We perform ablations to interrogate this, and report some potentially surprising results. Specifically: (1) Placing CoT sequences after labels (rather than before) realizes consistently better downstream performance -- this means that no student "reasoning" is necessary at test time to realize gains. (2) When rationales are appended in this way, they need not be coherent reasoning sequences to yield improvements; performance increases are robust to permutations of CoT tokens, for example. In fact, (3) a small number of key tokens are sufficient to achieve improvements equivalent to those observed when full rationales are used in model distillation.