Training-Inference Consistent Segmented Execution for Long-Context LLMs
作者: Xianpeng Shang, Jiang Li, Zehua Duo, Qianyi Cai, Xiangdong Su
分类: cs.CL, cs.LG
发布日期: 2026-05-12
备注: Accepted by ICML 2026. 19 pages, 6 figures, 3 tables
💡 一句话要点
提出训练-推理一致的分段执行框架,提升长文本LLM的效率和可扩展性
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 长文本LLM 分段执行 训练推理一致性 高效推理 KV缓存
📋 核心要点
- 现有长文本LLM方法在推理时采用分段执行以提高效率,但训练时仍使用全上下文注意力,导致训练和推理不一致。
- 论文提出训练-推理一致的分段执行框架,训练和推理采用相同的分段前向执行语义,保证状态转移的一致性。
- 实验结果表明,该方法在长文本基准测试中性能与全上下文注意力相当,同时显著降低了内存占用,提高了可扩展性。
📝 摘要(中文)
基于Transformer的大语言模型在长文本生成中面临严重的可扩展性挑战,这主要是由于全上下文注意力机制带来的计算和内存成本。为了在实际的计算和内存约束下提高效率,许多推理高效的长文本方法仅在推理阶段采用有界上下文或分段执行,而在训练阶段仍然采用全上下文注意力机制,导致训练和推理执行以及状态转移语义不匹配。基于此,我们提出了一个训练-推理一致的分段生成框架,其中训练和推理遵循相同的分段前向执行语义。在训练过程中,通过限制梯度传播到从紧邻的前一个段继承的KV状态来强制与推理保持一致,同时允许在不涉及梯度传播的情况下,在正向传播期间进行特定头的过去KV状态访问。在长文本基准测试中,我们的方法实现了与全上下文注意力相当的性能,同时在与强大的推理高效基线相比实现了有竞争力的延迟-内存权衡,并且在非常长的上下文长度下显著提高了可扩展性(例如,在128K上下文长度下,峰值预填充内存降低了约6倍,与使用FlashAttention的全上下文注意力相比)。
🔬 方法详解
问题定义:现有长文本LLM在处理长序列时,由于Transformer的全局注意力机制,计算和内存需求呈平方级增长。为了解决这个问题,许多方法在推理阶段采用分段或有界上下文注意力,但在训练阶段仍然使用全局注意力。这种训练和推理的不一致性导致模型性能下降,尤其是在长文本生成任务中。
核心思路:论文的核心思路是使训练和推理过程保持一致。具体来说,就是在训练阶段也采用分段执行的方式,使得模型在训练时就学习到如何在分段的上下文信息下进行预测,从而避免了训练和推理之间的gap。
技术框架:该框架的核心是分段执行。在训练和推理阶段,输入文本被分成多个段。模型逐段处理输入,并维护一个KV缓存,用于存储之前段的信息。在每个段的处理过程中,模型可以访问当前段和之前段的KV缓存。为了保证训练和推理的一致性,梯度传播被限制在当前段和紧邻的前一个段的KV状态上。
关键创新:该方法最重要的创新点在于训练和推理的一致性。通过在训练阶段引入分段执行,并限制梯度传播,使得模型能够更好地适应推理阶段的分段上下文信息,从而提高了长文本生成任务的性能。此外,该方法允许在正向传播期间进行特定头的过去KV状态访问,进一步提升了模型性能。
关键设计:在训练过程中,损失函数采用标准的交叉熵损失。关键在于梯度传播的限制,只允许梯度传播到当前段和紧邻的前一个段的KV状态。此外,论文还探索了不同的段长度和KV缓存大小对模型性能的影响。具体参数设置在论文中有详细描述。
🖼️ 关键图片
📊 实验亮点
实验结果表明,该方法在长文本基准测试中取得了与全上下文注意力相当的性能,同时显著降低了内存占用。例如,在128K上下文长度下,峰值预填充内存降低了约6倍,与使用FlashAttention的全上下文注意力相比。此外,该方法在延迟-内存权衡方面也优于其他推理高效的基线方法。
🎯 应用场景
该研究成果可应用于各种需要处理长文本的场景,例如长篇小说生成、法律文档分析、科学论文摘要等。通过降低内存占用和提高计算效率,该方法使得长文本LLM能够部署在资源受限的设备上,并加速长文本生成任务的推理速度。未来,该方法可以进一步扩展到处理更长的文本序列,并与其他长文本处理技术相结合,以实现更好的性能。
📄 摘要(原文)
Transformer-based large language models face severe scalability challenges in long-context generation due to the computational and memory costs of full-context attention. Under practical computation and memory constraints, many inference-efficient long-context methods improve efficiency by adopting bounded-context or segment-level execution only during inference, while continuing to train models under full-context attention, resulting in a mismatch between training and inference execution and state-transition semantics. Based on this insight, we propose a training-inference consistent segment-level generation framework, in which training and inference follow the same segment-level forward execution semantics. During training, consistency with inference is enforced by restricting gradient propagation to KV states carried over from the immediately preceding segment, while permitting head-specific access to past KV states during the forward pass without involving them in gradient propagation. Across long-context benchmarks, our approach achieves performance comparable to full-context attention, while achieving competitive latency-memory trade-offs against strong inference-efficient baselines, and substantially improving scalability at very long context lengths (e.g., approximately 6x lower peak prefill memory at 128K compared to full-context attention with FlashAttention).