Thinking Slow, Fast: Scaling Inference Compute with Distilled Reasoners

📄 arXiv: 2502.20339v1 📥 PDF

作者: Daniele Paliotta, Junxiong Wang, Matteo Pagliardini, Kevin Y. Li, Aviv Bick, J. Zico Kolter, Albert Gu, François Fleuret, Tri Dao

分类: cs.CL, cs.AI

发布日期: 2025-02-27


💡 一句话要点

利用蒸馏推理器扩展计算资源,提升LLM在数学推理任务上的效率与性能。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大型语言模型 知识蒸馏 Mamba架构 数学推理 推理加速

📋 核心要点

  1. 大型语言模型(LLM)的性能可以通过在测试时扩展计算资源来显著提高,但现有方法通常计算成本高昂。
  2. 论文提出通过知识蒸馏,将Transformer模型的推理能力迁移到更高效的Mamba架构上,从而在相同计算资源下实现更高的推理吞吐量。
  3. 实验表明,蒸馏后的Mamba模型在数学推理任务上,能够在固定时间预算下超越Transformer教师模型,验证了该方法的有效性。

📝 摘要(中文)

本文研究了在固定计算预算下,低复杂度模型是否能通过更高的生成吞吐量超越同等规模的Transformer。为了解决缺乏强大亚二次推理器的问题,作者从预训练的Transformer中蒸馏出纯Mamba和混合Mamba模型。这些模型仅在80亿token上训练,在数学推理数据集上表现出强大的性能和可扩展性,同时在大批量和长序列推理时速度更快。尽管蒸馏导致零样本性能下降,但纯Mamba和混合Mamba模型在固定时间预算下,其覆盖率和准确率均能超过Transformer教师模型,为扩展推理计算开辟了新方向。

🔬 方法详解

问题定义:论文旨在解决大型语言模型在推理时计算成本高昂的问题,尤其是在需要进行多次Chain-of-Thought (CoT) 推理并聚合结果时。现有方法,如直接扩展Transformer模型,虽然可以提高性能,但计算复杂度较高,限制了实际应用。

核心思路:论文的核心思路是通过知识蒸馏,将高性能Transformer模型的推理能力迁移到计算效率更高的Mamba架构上。Mamba架构具有亚二次计算复杂度,因此在处理长序列时具有显著的优势。通过蒸馏,可以在保持甚至提升性能的同时,显著降低推理成本。

技术框架:整体框架包括以下几个步骤:1) 使用预训练的Transformer模型作为教师模型;2) 构建包含数学推理任务的数据集;3) 使用教师模型生成CoT轨迹,作为Mamba模型的训练数据;4) 训练纯Mamba模型和混合Mamba模型(Mamba与Transformer的混合);5) 在数学推理数据集上评估蒸馏模型的性能。

关键创新:最重要的技术创新点在于成功地将Transformer模型的推理能力迁移到Mamba架构上,从而实现了在固定计算预算下更高的推理性能。与直接扩展Transformer模型相比,该方法能够在保证性能的同时,显著降低推理成本。此外,论文还探索了纯Mamba模型和混合Mamba模型,为未来的研究提供了更多选择。

关键设计:论文的关键设计包括:1) 使用80亿token进行训练;2) 采用CoT轨迹作为训练数据,以提高Mamba模型的推理能力;3) 探索了纯Mamba模型和混合Mamba模型,并比较了它们的性能;4) 在数学推理数据集上进行了广泛的实验,以评估蒸馏模型的性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,蒸馏后的Mamba模型在数学推理任务上表现出强大的性能和可扩展性。在固定时间预算下,纯Mamba和混合Mamba模型在覆盖率和准确率方面均能超过Transformer教师模型。这表明,通过知识蒸馏,可以将Transformer模型的推理能力迁移到更高效的Mamba架构上,从而在保证性能的同时,显著降低推理成本。

🎯 应用场景

该研究成果可应用于各种需要高效推理的场景,例如数学问题求解、代码生成、自然语言理解等。通过降低推理成本,可以使大型语言模型更易于部署在资源受限的环境中,并促进其在实际应用中的普及。此外,该方法还可以用于训练其他高效的推理器,从而进一步提升人工智能系统的性能。

📄 摘要(原文)

Recent advancements have demonstrated that the performance of large language models (LLMs) can be significantly enhanced by scaling computational resources at test time. A common strategy involves generating multiple Chain-of-Thought (CoT) trajectories and aggregating their outputs through various selection mechanisms. This raises a fundamental question: can models with lower complexity leverage their superior generation throughput to outperform similarly sized Transformers for a fixed computational budget? To address this question and overcome the lack of strong subquadratic reasoners, we distill pure and hybrid Mamba models from pretrained Transformers. Trained on only 8 billion tokens, our distilled models show strong performance and scaling on mathematical reasoning datasets while being much faster at inference for large batches and long sequences. Despite the zero-shot performance hit due to distillation, both pure and hybrid Mamba models can scale their coverage and accuracy performance past their Transformer teacher models under fixed time budgets, opening a new direction for scaling inference compute.