Think, Prune, Train, Improve: Scaling Reasoning without Scaling Models

📄 arXiv: 2504.18116v1 📥 PDF

作者: Caia Costello, Simon Guo, Anna Goldie, Azalia Mirhoseini

分类: cs.LG

发布日期: 2025-04-25


💡 一句话要点

提出Think, Prune, Train框架,无需扩大模型即可提升LLM推理能力

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大型语言模型 推理能力 合成数据 微调 剪枝 自提升 GSM8K

📋 核心要点

  1. 现有LLM推理能力受限于高质量训练数据,合成数据微调受多种因素影响,难以实现有效自提升。
  2. 提出Think, Prune, Train框架,利用模型自身推理轨迹进行微调,并使用ground-truth剪枝保证数据质量。
  3. 实验表明,该方法显著提升模型在GSM8K数据集上的推理性能,甚至超越更大规模的模型。

📝 摘要(中文)

大型语言模型(LLMs)在编程和数学推理任务中表现出强大的能力,但受到高质量训练数据有限的约束。合成数据可以用来增强微调效果,但包括模型大小、合成数据量、剪枝策略和微调轮数等多种因素会影响这一过程。我们探索了这些维度,并研究了哪些条件能够实现模型的自我改进。我们引入了Think, Prune, Train过程,这是一个可扩展的框架,它使用ground-truth剪枝来确保高质量的训练数据,从而在模型自身的推理轨迹上迭代地微调模型。这种方法产生了改进的性能:在GSM8K上,Gemma2-2B的Pass@1达到57.6%(从41.9%提升),Gemma2-9B达到82%,与LLaMA-3.1-70B相匹配,LLaMA-3.1-70B达到91%,甚至超过了GPT-4o,证明了自我生成的推理和系统的数据选择对于提高LLM能力是有效的。

🔬 方法详解

问题定义:论文旨在解决大型语言模型在推理任务中,由于高质量训练数据不足而导致的性能瓶颈问题。现有方法依赖于人工标注或未经筛选的合成数据,难以保证训练数据的质量,从而限制了模型性能的提升。

核心思路:论文的核心思路是利用模型自身生成的推理轨迹作为训练数据,并通过ground-truth剪枝策略,去除推理过程中错误的步骤,从而保证训练数据的质量。这种方法允许模型在高质量的自我生成的数据上进行迭代学习,实现性能的自提升。

技术框架:Think, Prune, Train框架包含三个主要阶段:1) Think:模型生成推理轨迹;2) Prune:使用ground-truth对推理轨迹进行剪枝,去除错误步骤;3) Train:使用剪枝后的高质量推理轨迹对模型进行微调。该过程可以迭代进行,使模型在自我生成的、高质量的数据上不断学习和改进。

关键创新:该方法最重要的创新点在于利用ground-truth剪枝策略来保证训练数据的质量。与传统的合成数据方法不同,该方法能够有效地去除推理过程中的错误步骤,从而避免了模型在错误数据上进行学习,提高了微调的效率和效果。

关键设计:关键设计包括:1) 使用模型自身的推理轨迹作为训练数据,保证了数据与模型的一致性;2) 使用ground-truth剪枝策略,确保训练数据的质量;3) 迭代的Think, Prune, Train过程,允许模型在自我生成的、高质量的数据上不断学习和改进。具体的参数设置和损失函数选择可能根据不同的模型和任务进行调整,论文中未明确说明。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,Think, Prune, Train框架能够显著提升LLM的推理性能。在GSM8K数据集上,Gemma2-2B的Pass@1从41.9%提升至57.6%,Gemma2-9B达到82%,与LLaMA-3.1-70B性能相当,LLaMA-3.1-70B更是达到了91%,甚至超越了GPT-4o。

🎯 应用场景

该研究成果可应用于提升各种LLM在数学、编程等推理任务中的性能。通过自我生成和高质量数据筛选,降低了对大规模人工标注数据的依赖,具有广泛的应用前景,并有望推动LLM在资源受限场景下的应用。

📄 摘要(原文)

Large language models (LLMs) have demonstrated strong capabilities in programming and mathematical reasoning tasks, but are constrained by limited high-quality training data. Synthetic data can be leveraged to enhance fine-tuning outcomes, but several factors influence this process, including model size, synthetic data volume, pruning strategy, and number of fine-tuning rounds. We explore these axes and investigate which conditions enable model self-improvement. We introduce the Think, Prune, Train process, a scalable framework that iteratively fine-tunes models on their own reasoning traces, using ground-truth pruning to ensure high-quality training data. This approach yields improved performance: on GSM8K, Gemma2-2B achieves a Pass@1 of 57.6% (from 41.9%), Gemma2-9B reaches 82%, matching LLaMA-3.1-70B, and LLaMA-3.1-70B attains 91%, even surpassing GPT-4o, demonstrating the effectiveness of self-generated reasoning and systematic data selection for improving LLM capabilities.