SplitReason: Learning To Offload Reasoning

📄 arXiv: 2504.16379v1 📥 PDF

作者: Yash Akhauri, Anthony Fei, Chi-Chih Chang, Ahmed F. AbouElhamayed, Yueying Li, Mohamed S. Abdelfattah

分类: cs.CL

发布日期: 2025-04-23


💡 一句话要点

SplitReason:通过学习卸载推理任务提升大语言模型效率与精度。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 推理卸载 链式思维 强化学习 模型微调

📋 核心要点

  1. 大语言模型推理过程计算成本高昂,但并非所有步骤难度一致,存在优化空间。
  2. 提出SplitReason方法,训练小模型识别并卸载困难推理步骤给大模型,实现效率与精度的平衡。
  3. 实验表明,该方法在少量token卸载的情况下,显著提升了数学推理任务的准确率。

📝 摘要(中文)

大型语言模型(LLM)中的推理过程通常比简单的语言建模任务产生更长的token序列。这种扩展的生成长度反映了推理的多步骤和组合性质,并且通常与更高的解决方案准确性相关。从效率的角度来看,更长的token生成加剧了LLM固有的顺序和内存受限的解码阶段。然而,并非所有昂贵的推理过程都同样难以生成。我们利用这一观察结果,仅将推理过程中最具挑战性的部分卸载到更大、更强大的模型,同时使用更小、更高效的模型执行大部分生成;此外,我们训练较小的模型来识别这些困难的部分,并在需要时独立触发卸载。为了实现这种行为,我们从OpenR1-Math-220k链式思维(CoT)数据集中标注了18k推理轨迹中的困难部分。然后,我们将监督微调(SFT)和强化学习微调(RLFT)应用于一个15亿参数的推理模型,训练它学习将自身推理过程中最具挑战性的部分卸载到更大的模型。这种方法分别在卸载1.35%和5%的生成token的情况下,将AIME24推理准确率提高了24%和28.3%。我们开源了我们的SplitReason模型、数据、代码和日志。

🔬 方法详解

问题定义:现有的大语言模型在进行复杂推理任务时,需要生成很长的token序列,这导致计算成本很高,尤其是在解码阶段。虽然增加模型规模可以提高推理精度,但也会进一步增加计算负担。现有的方法没有充分利用推理过程中不同步骤的难度差异,导致资源浪费。

核心思路:SplitReason的核心思想是将推理过程分解为简单和困难两部分,并让一个小模型学习识别哪些部分是困难的,然后将这些困难的部分卸载给一个更大的模型来处理。这样,大部分简单的推理步骤仍然由小模型高效地完成,只有最需要算力的部分才交给大模型,从而在精度和效率之间取得平衡。

技术框架:SplitReason包含以下主要阶段:1) 数据标注:在推理轨迹中人工标注出困难的推理步骤。2) 监督微调(SFT):使用标注数据微调小模型,使其能够识别并标记需要卸载的token。3) 强化学习微调(RLFT):使用强化学习进一步优化小模型的卸载策略,使其在精度和效率之间达到最佳平衡。整体架构是小模型负责大部分token生成,当遇到困难步骤时,触发卸载机制,将该部分交给大模型处理,最后将大模型的输出合并到小模型的生成序列中。

关键创新:SplitReason最重要的创新在于让小模型具备了自主学习和判断推理难度的能力,并能够根据自身能力动态地将任务分配给更大的模型。这种自适应的卸载机制能够更有效地利用计算资源,避免了对所有推理步骤都使用大模型进行处理的浪费。

关键设计:在数据标注方面,需要仔细选择标注策略,以确保标注的质量和一致性。在SFT阶段,需要设计合适的损失函数来训练小模型识别困难步骤。在RLFT阶段,需要定义合适的奖励函数,以鼓励小模型在精度和效率之间做出合理的权衡。具体的网络结构细节和参数设置在论文中未详细说明,属于未知信息。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,SplitReason方法在AIME24数学推理任务上取得了显著的提升。在仅卸载1.35%的token的情况下,准确率提高了24%;在卸载5%的token的情况下,准确率提高了28.3%。这表明该方法能够有效地识别并卸载困难的推理步骤,从而在精度和效率之间取得良好的平衡。

🎯 应用场景

SplitReason方法具有广泛的应用前景,可以应用于各种需要复杂推理的场景,例如数学问题求解、代码生成、知识图谱推理等。通过将推理任务分解并分配给不同规模的模型,可以显著降低计算成本,提高推理效率,并促进大语言模型在资源受限环境中的部署。

📄 摘要(原文)

Reasoning in large language models (LLMs) tends to produce substantially longer token generation sequences than simpler language modeling tasks. This extended generation length reflects the multi-step, compositional nature of reasoning and is often correlated with higher solution accuracy. From an efficiency perspective, longer token generation exacerbates the inherently sequential and memory-bound decoding phase of LLMs. However, not all parts of this expensive reasoning process are equally difficult to generate. We leverage this observation by offloading only the most challenging parts of the reasoning process to a larger, more capable model, while performing most of the generation with a smaller, more efficient model; furthermore, we teach the smaller model to identify these difficult segments and independently trigger offloading when needed. To enable this behavior, we annotate difficult segments across 18k reasoning traces from the OpenR1-Math-220k chain-of-thought (CoT) dataset. We then apply supervised fine-tuning (SFT) and reinforcement learning fine-tuning (RLFT) to a 1.5B-parameter reasoning model, training it to learn to offload the most challenging parts of its own reasoning process to a larger model. This approach improves AIME24 reasoning accuracy by 24% and 28.3% while offloading 1.35% and 5% of the generated tokens respectively. We open-source our SplitReason model, data, code and logs.