Beyond Imitation: Learning Key Reasoning Steps from Dual Chain-of-Thoughts in Reasoning Distillation
作者: Chengwei Dai, Kun Li, Wei Zhou, Songlin Hu
分类: cs.CL, cs.AI
发布日期: 2024-05-30
🔗 代码/项目: GITHUB
💡 一句话要点
提出EDIT方法,通过错误驱动的关键推理步骤蒸馏提升小模型推理能力。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 语言模型蒸馏 链式思考 关键推理步骤 错误驱动学习 最小编辑距离
📋 核心要点
- 现有蒸馏方法难以让小模型学习CoT推理中的关键步骤,导致其推理能力受限。
- EDIT方法通过分析双重CoT数据中的错误,定位关键推理步骤并进行针对性学习。
- 实验表明,EDIT方法能有效提升小模型在领域内和领域外数据集上的推理性能。
📝 摘要(中文)
随着大型语言模型(LLMs)规模的扩大和链式思考(CoTs)推理能力的增强,实际资源约束推动了将这些能力提炼到更紧凑的小型语言模型(SLMs)中的努力。我们发现CoTs主要由简单的推理形式组成,其中只有一小部分(约4.7%)的关键推理步骤真正影响结论。然而,以往的蒸馏方法通常只在教师LLM生成的正确CoTs数据上对学生SLM进行监督微调,导致学生难以学习关键推理步骤,而是模仿教师的推理形式,并在这些步骤上出错或遗漏。为了解决这些问题,我们借鉴人类学习的经验,即根据正确的解决方案分析错误通常会揭示导致成功或失败的关键步骤,我们提出了一种mistakE-Driven key reasonIng step distillaTion (EDIT)的新方法,该方法进一步帮助SLM学习关键推理步骤,而不是仅仅进行简单的微调。首先,为了揭示CoTs中的这些关键步骤,我们设计了特定的提示来生成具有相似推理路径但结论不同的双重CoTs数据。然后,我们对双重CoTs数据应用最小编辑距离算法来定位这些关键步骤,并优化这些步骤的可能性。大量的实验验证了EDIT在领域内和领域外基准推理数据集上的有效性。进一步的分析表明,EDIT可以生成具有更多正确关键推理步骤的高质量CoTs。值得注意的是,我们还探讨了不同的错误模式如何影响性能,并发现EDIT从双重CoTs中的逻辑错误中获益更多,而不是从知识或数学计算错误中获益。
🔬 方法详解
问题定义:现有的大语言模型蒸馏方法,特别是基于链式思考(CoT)的蒸馏,通常只是简单地使用大型语言模型生成的正确推理过程来微调小型语言模型。这种方法忽略了CoT中关键推理步骤的重要性,导致小型模型只是模仿大型模型的推理形式,而无法真正理解和掌握关键的推理逻辑,从而在遇到复杂问题时容易出错。
核心思路:论文的核心思路是借鉴人类从错误中学习的经验。通过分析相似但结论不同的双重CoT数据,可以更容易地识别出导致推理失败的关键步骤。然后,通过优化这些关键步骤,可以更有效地提升小型模型的推理能力。
技术框架:EDIT方法主要包含以下几个阶段: 1. 双重CoT数据生成:设计特定的提示,引导大型语言模型生成具有相似推理路径但最终结论不同的两组CoT数据。 2. 关键步骤定位:使用最小编辑距离算法,比较双重CoT数据,找出导致结论差异的关键推理步骤。 3. 关键步骤优化:通过优化关键推理步骤的似然性,使小型模型能够更准确地学习和掌握这些步骤。
关键创新:EDIT方法的关键创新在于其错误驱动的学习方式。与传统的监督微调方法不同,EDIT方法不是简单地模仿大型模型的推理过程,而是通过分析错误来定位关键步骤,并针对性地进行学习。这种方法能够更有效地提升小型模型的推理能力,使其能够更好地泛化到新的问题上。
关键设计: * 双重CoT生成提示:提示的设计需要保证生成的CoT数据具有相似的推理路径,但最终结论不同,以便于后续的关键步骤定位。 * 最小编辑距离算法:选择合适的编辑距离算法,并设置合理的参数,以准确地识别出关键推理步骤。 * 损失函数:设计合适的损失函数,以优化关键推理步骤的似然性。例如,可以使用交叉熵损失函数,鼓励模型生成正确的关键推理步骤。
🖼️ 关键图片
📊 实验亮点
实验结果表明,EDIT方法在多个基准推理数据集上都取得了显著的性能提升。例如,在某些数据集上,EDIT方法可以将小型模型的准确率提升超过5%。此外,分析表明,EDIT方法生成的CoT数据具有更高的质量,包含更多正确的关键推理步骤。研究还发现,EDIT方法从逻辑错误中获益更多,而不是从知识或数学计算错误中获益。
🎯 应用场景
EDIT方法可应用于各种需要复杂推理能力的场景,例如问答系统、数学问题求解、代码生成等。通过将大型语言模型的推理能力蒸馏到小型模型中,可以在资源受限的环境下部署高性能的推理系统,例如移动设备或嵌入式系统。此外,该方法还可以用于提升教育领域的智能化水平,例如自动批改作业、个性化辅导等。
📄 摘要(原文)
As Large Language Models (LLMs) scale up and gain powerful Chain-of-Thoughts (CoTs) reasoning abilities, practical resource constraints drive efforts to distill these capabilities into more compact Smaller Language Models (SLMs). We find that CoTs consist mainly of simple reasoning forms, with a small proportion ($\approx 4.7\%$) of key reasoning steps that truly impact conclusions. However, previous distillation methods typically involve supervised fine-tuning student SLMs only on correct CoTs data produced by teacher LLMs, resulting in students struggling to learn the key reasoning steps, instead imitating the teacher's reasoning forms and making errors or omissions on these steps. To address these issues, drawing an analogy to human learning, where analyzing mistakes according to correct solutions often reveals the crucial steps leading to successes or failures, we propose mistak\textbf{E}-\textbf{D}riven key reason\textbf{I}ng step distilla\textbf{T}ion (\textbf{EDIT}), a novel method that further aids SLMs learning key reasoning steps rather than mere simple fine-tuning. Firstly, to expose these crucial steps in CoTs, we design specific prompts to generate dual CoTs data with similar reasoning paths but divergent conclusions. Then, we apply the minimum edit distance algorithm on the dual CoTs data to locate these key steps and optimize the likelihood of these steps. Extensive experiments validate the effectiveness of EDIT across both in-domain and out-of-domain benchmark reasoning datasets. Further analysis shows that EDIT can generate high-quality CoTs with more correct key reasoning steps. Notably, we also explore how different mistake patterns affect performance and find that EDIT benefits more from logical errors than from knowledge or mathematical calculation errors in dual CoTs\footnote{Code can be found at \url{https://github.com/C-W-D/EDIT}}.