Enhancing Generalization in Chain of Thought Reasoning for Smaller Models

📄 arXiv: 2501.09804v1 📥 PDF

作者: Maxwell J. Yin, Dingyi Jiang, Yongbing Chen, Boyu Wang, Charles Ling

分类: cs.LG, cs.AI, cs.CL

发布日期: 2025-01-16


💡 一句话要点

提出PRADA框架,提升小模型在思维链推理中的泛化能力

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 思维链推理 知识蒸馏 领域对抗训练 小模型 泛化能力 提示工程 可解释性

📋 核心要点

  1. 现有CoT知识蒸馏方法在小型LLM中泛化能力不足,存在过度保守记忆的问题。
  2. 提出PRADA框架,通过领域对抗微调和提示工程增强模型在不同CoT领域的适应性。
  3. 实验结果表明,PRADA在多个任务上显著优于现有方法,并提升了模型的可解释性。

📝 摘要(中文)

思维链(CoT)推理在小型语言模型中是一个具有挑战性的自然语言处理问题,但在许多实际应用中非常需要。现有的CoT知识蒸馏方法通常会导致小型LLM中过度保守的记忆,从而导致泛化置信度较低。由于完全保留教师模型的CoT能力是不可能的,我们假设对抗性CoT微调对于开发具有鲁棒CoT泛化能力的小型LLM至关重要。为此,我们提出了PRompt-Assisted Domain-Adversarial fine-tuning(PRADA),这是一个集成了多样化CoT领域的原则性微调框架。具体来说,PRADA在小型LLM中率先实现了两项CoT改进:(1)通过领域对抗微调恢复在蒸馏过程中通常丢失的领域不变特征洞察;(2)通过采用领域对抗方法增强CoT提示工程的领域适应性。我们从理论上证明了我们方法的有效性,并通过实验表明它在各种任务中显著优于最先进的方法。此外,我们的实验结果表明,小型LLM在利用PRADA时,与领域知识紧密结合,从而提高了我们方法的可解释性。

🔬 方法详解

问题定义:论文旨在解决小型语言模型在思维链(CoT)推理中泛化能力不足的问题。现有CoT知识蒸馏方法存在过度保守记忆的痛点,导致模型在面对新领域或任务时表现不佳,无法有效利用CoT进行推理。

核心思路:论文的核心思路是通过对抗性微调来提升小型LLM的CoT泛化能力。具体而言,通过领域对抗训练,使模型能够学习到领域不变的特征,并增强CoT提示工程的领域适应性,从而提高模型在不同领域和任务上的推理能力。

技术框架:PRADA框架包含两个主要组成部分:领域对抗微调和提示工程增强。领域对抗微调旨在恢复在知识蒸馏过程中丢失的领域不变特征,通过最小化领域分类器的准确率,鼓励模型学习与领域无关的特征表示。提示工程增强则通过领域对抗方法,使模型能够更好地适应不同领域的CoT提示。整体流程是先进行领域对抗微调,然后利用增强的提示工程进行推理。

关键创新:PRADA的关键创新在于将领域对抗训练引入到CoT推理的小型LLM微调中。与传统的知识蒸馏方法不同,PRADA不仅关注模仿教师模型的CoT能力,更注重提升模型在不同领域的泛化能力。通过领域对抗训练,模型能够学习到更鲁棒的特征表示,从而更好地适应新的领域和任务。

关键设计:PRADA使用领域分类器来区分不同的CoT领域,并使用梯度反转层(Gradient Reversal Layer)来对抗领域分类器,从而鼓励模型学习领域不变的特征。损失函数包括CoT推理损失、领域分类损失和对抗损失。提示工程增强则通过调整提示模板,使其更具领域适应性。具体的参数设置和网络结构细节在论文中未明确给出,属于未知信息。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,PRADA框架在多个CoT推理任务上显著优于现有方法。具体性能数据和对比基线在摘要中未明确给出,属于未知信息。但论文强调,PRADA能够使小型LLM与领域知识更紧密结合,从而提高模型的可解释性。

🎯 应用场景

该研究成果可应用于各种需要小型语言模型进行复杂推理的场景,例如移动设备上的智能助手、资源受限的边缘计算设备以及需要快速部署的特定领域应用。通过提升小模型的CoT推理能力,可以降低计算成本,提高响应速度,并扩展语言模型在实际应用中的范围。

📄 摘要(原文)

Chain-of-Thought (CoT) reasoning in smaller language models is a challenging natural language process problem yet highly desirable in many real-life applications. Existing CoT knowledge distillation methods often suffer from overly conservative memorization in smaller LLMs, leading to low generalization confidence. As fully preserving the CoT ability of teacher model is impossible, we hypothesize that adversarial CoT fine-tuning is crucial for developing smaller LLM with robust CoT generalization. To this end, we propose \textit{PRompt-Assisted Domain-Adversarial fine-tuning} (PRADA), a principled fine-tuning framework that integrates diverse CoT domains. Specifically, PRADA pioneers two CoT improvements in smaller LLM: (1) Recovering the domain-invariant feature insight which typically lost during distillation with domain adversarial fine-tuning; (2) Enhancing the domain adaptability of CoT prompt engineering by employing domain-adversarial approaches. We theoretically demonstrate the effectiveness of our approach and empirically show that it significantly outperforms the state of the arts in a wide range of tasks. Moreover, our empirical findings reveal that the smaller LLM, when leveraging PRADA, aligns closely with domain knowledge, thereby improving the explainability of our approach.