One-Pass to Reason: Token Duplication and Block-Sparse Mask for Efficient Fine-Tuning on Multi-Turn Reasoning
作者: Ritesh Goru, Shanay Mehta, Prateek Jain
分类: cs.CL, cs.AI, cs.LG
发布日期: 2025-04-25 (更新: 2025-07-11)
备注: 9 pages, 3 figures
🔗 代码/项目: GITHUB
💡 一句话要点
提出一种基于Token复制和块稀疏掩码的单次推理微调方法,加速多轮对话LLM训练。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 多轮对话 大语言模型 微调 注意力机制 Token复制 块稀疏掩码 单次推理 训练加速
📋 核心要点
- 传统多轮推理LLM微调需多次前向传播,效率低下,因为每一轮的推理token在后续轮次中会被丢弃。
- 该论文提出复制响应token并设计块稀疏注意力掩码,实现单次前向传播,降低时间复杂度。
- 实验表明,该方法在保持准确率的同时,显著加速了训练过程,提升了训练效率。
📝 摘要(中文)
在大语言模型(LLM)上对多轮推理数据集进行微调时,由于推理token的可见性约束,每个对话需要N(轮数)次单独的前向传播,因为每一轮的推理token在后续轮次中会被丢弃。我们提出复制响应token,并结合自定义的注意力掩码,以实现对整个对话的单次处理。我们证明了我们的方法在transformer模型中产生与N次传播方法相同的损失,同时将时间复杂度从$Oigl(N^{3}igl)$降低到$Oigl(N^{2}igl)$,并保持相同的内存复杂度。我们的方法在保持准确性的同时,实现了显著的训练加速。我们的实现已在线提供(https://github.com/devrev/One-Pass-to-Reason)。
🔬 方法详解
问题定义:论文旨在解决多轮对话场景下,对大型语言模型(LLM)进行微调时效率低下的问题。传统方法需要对每个对话轮次进行单独的前向传播,因为每一轮的推理token在后续轮次中会被丢弃,导致计算冗余和时间复杂度高。现有方法的痛点在于无法充分利用整个对话的上下文信息,导致训练效率低下。
核心思路:论文的核心思路是通过复制响应token,并结合自定义的块稀疏注意力掩码,将整个多轮对话压缩成一个单一的输入序列,从而实现单次前向传播。这样可以避免重复计算,并充分利用整个对话的上下文信息,从而提高训练效率。
技术框架:整体框架包括以下几个主要步骤:1)将多轮对话的每一轮的输入和响应token进行拼接。2)复制响应token,并将其添加到输入序列中。3)构建块稀疏注意力掩码,控制token之间的可见性,确保每一轮的推理token只对当前轮次和之前的轮次可见。4)将处理后的输入序列输入到Transformer模型中进行训练。
关键创新:最重要的技术创新点在于提出了token复制和块稀疏注意力掩码相结合的方法。与现有方法相比,该方法能够将多轮对话压缩成一个单一的输入序列,从而实现单次前向传播,显著降低了时间复杂度。本质区别在于,传统方法需要多次前向传播,而该方法只需要一次。
关键设计:关键设计包括:1)响应token的复制策略,确保每个响应token都被复制到其对应的轮次中。2)块稀疏注意力掩码的设计,确保每一轮的推理token只对当前轮次和之前的轮次可见,避免信息泄露。3)损失函数与标准LLM微调相同,以保证模型性能。
🖼️ 关键图片
📊 实验亮点
论文提出的方法在保持准确率的同时,显著加速了多轮对话LLM的微调过程。具体来说,该方法将时间复杂度从$Oigl(N^{3}igl)$降低到$Oigl(N^{2}igl)$,同时保持了相同的内存复杂度。实验结果表明,该方法能够实现显著的训练加速,例如在某些数据集上可以达到2倍以上的加速效果。
🎯 应用场景
该研究成果可广泛应用于各种需要多轮对话推理的场景,例如智能客服、对话式问答系统、任务型对话系统等。通过提高LLM在多轮对话场景下的训练效率,可以加速相关应用的开发和部署,提升用户体验,并降低计算成本。未来,该方法可以进一步扩展到更复杂的对话场景,例如涉及外部知识库的对话、多模态对话等。
📄 摘要(原文)
Fine-tuning Large Language Models (LLMs) on multi-turn reasoning datasets requires N (number of turns) separate forward passes per conversation due to reasoning token visibility constraints, as reasoning tokens for a turn are discarded in subsequent turns. We propose duplicating response tokens along with a custom attention mask to enable single-pass processing of entire conversations. We prove our method produces identical losses to the N-pass approach while reducing time complexity from $O\bigl(N^{3}\bigl)$ to $O\bigl(N^{2}\bigl)$ and maintaining the same memory complexity for a transformer based model. Our approach achieves significant training speedup while preserving accuracy. Our implementation is available online (https://github.com/devrev/One-Pass-to-Reason).