LongReD: Mitigating Short-Text Degradation of Long-Context Large Language Models via Restoration Distillation
作者: Zican Dong, Junyi Li, Jinhao Jiang, Mingyu Xu, Wayne Xin Zhao, Bingning Wang, Weipeng Chen
分类: cs.CL, cs.LG
发布日期: 2025-02-11 (更新: 2025-05-28)
备注: ACL2025 Main
🔗 代码/项目: GITHUB
💡 一句话要点
LongReD:通过恢复蒸馏缓解长文本大模型短文本性能退化
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 长文本建模 短文本退化 蒸馏学习 持续预训练 知识迁移
📋 核心要点
- 现有长文本大模型在扩展上下文窗口后,短文本任务性能出现退化,原因在于隐藏状态分布漂移和灾难性遗忘。
- LongReD通过恢复蒸馏,最小化扩展模型和原始模型在短文本上的分布差异,从而缓解短文本性能下降。
- 实验结果表明,LongReD在保持或提升长文本处理能力的同时,有效保留了模型的短文本性能。
📝 摘要(中文)
大型语言模型(LLMs)通过扩展位置编码和轻量级的持续预训练获得了更长的上下文窗口。然而,这通常会导致在短文本任务上的性能下降,但这种下降的原因尚未得到充分探索。本文确定了导致此问题两个主要因素:隐藏状态和注意力分数中的分布漂移,以及持续预训练期间的灾难性遗忘。为了应对这些挑战,我们提出了一种名为LongReD(Long Context Pre-training with Restoration Distillation)的新方法,旨在通过最小化扩展模型和原始模型之间的分布差异来缓解短文本性能下降。除了在长文本上进行训练外,LongReD还从原始模型中提取选定层的隐藏状态,并在短文本上进行蒸馏。此外,LongReD还引入了短到长的蒸馏,通过利用跳过的位置索引,使短文本上的输出分布与长文本上的输出分布对齐。在常见文本基准上的实验表明,LongReD有效地保留了模型的短文本性能,同时保持了与基线相当甚至更好的处理长文本的能力。
🔬 方法详解
问题定义:现有的大型语言模型在扩展上下文窗口以处理更长的文本时,往往会在短文本任务上表现出性能下降。这种现象表明,模型在适应长文本的过程中,丢失了对短文本的有效处理能力。现有的方法未能充分解决这种短文本性能退化问题,需要一种新的训练策略来同时保持长文本处理能力和短文本性能。
核心思路:LongReD的核心思路是通过蒸馏学习,将原始模型在短文本上的知识迁移到扩展后的模型中,从而缓解短文本性能的退化。具体来说,LongReD通过最小化扩展模型和原始模型在短文本上的隐藏状态分布差异,以及输出分布差异,来恢复模型对短文本的理解能力。这种方法旨在弥合长文本预训练带来的分布漂移,并减轻灾难性遗忘的影响。
技术框架:LongReD的整体框架包括两个主要的蒸馏阶段:隐藏状态蒸馏和输出分布蒸馏。首先,在短文本上,LongReD从原始模型中提取选定层的隐藏状态,并将其作为目标,训练扩展后的模型,使其隐藏状态与原始模型对齐。其次,LongReD引入了短到长的蒸馏,通过利用跳过的位置索引,使扩展后的模型在短文本上的输出分布与原始模型在长文本上的输出分布对齐。这两个阶段共同作用,以恢复模型对短文本的理解能力。
关键创新:LongReD的关键创新在于其恢复蒸馏策略,该策略通过最小化隐藏状态和输出分布的差异,有效地缓解了长文本预训练带来的短文本性能退化问题。与传统的蒸馏方法不同,LongReD不仅关注输出分布的对齐,还关注隐藏状态的对齐,从而更全面地恢复模型对短文本的理解能力。此外,短到长的蒸馏策略也是一个创新点,它通过利用跳过的位置索引,实现了短文本和长文本之间的知识迁移。
关键设计:LongReD的关键设计包括以下几个方面:1) 隐藏状态蒸馏中,选择哪些层进行蒸馏是一个重要的参数,需要根据具体任务进行调整。2) 短到长的蒸馏中,如何选择跳过的位置索引也是一个关键的设计,需要根据长文本的长度和短文本的长度进行调整。3) 损失函数的设计也至关重要,需要平衡隐藏状态蒸馏损失和输出分布蒸馏损失,以达到最佳的性能。
🖼️ 关键图片
📊 实验亮点
实验结果表明,LongReD在多个文本基准测试中表现出色,在保持甚至提升长文本处理能力的同时,有效保留了模型的短文本性能。具体来说,LongReD在某些短文本任务上的性能甚至超过了原始模型,证明了其恢复蒸馏策略的有效性。与基线模型相比,LongReD在短文本任务上的性能提升显著,证明了其在缓解短文本性能退化方面的优势。
🎯 应用场景
LongReD具有广泛的应用前景,可以应用于各种需要处理长文本和短文本的自然语言处理任务,例如文档摘要、问答系统、文本分类等。该方法可以提高模型在处理混合长度文本时的性能,从而提升用户体验和应用效果。未来,LongReD可以进一步扩展到其他领域,例如图像处理和语音识别,以解决类似的长短序列处理问题。
📄 摘要(原文)
Large language models (LLMs) have gained extended context windows through scaling positional encodings and lightweight continual pre-training. However, this often leads to degraded performance on short-text tasks, while the reasons for this degradation remain insufficiently explored. In this work, we identify two primary factors contributing to this issue: distribution drift in hidden states and attention scores, and catastrophic forgetting during continual pre-training. To address these challenges, we propose Long Context Pre-training with Restoration Distillation (LongReD), a novel approach designed to mitigate short-text performance degradation through minimizing the distribution discrepancy between the extended and original models. Besides training on long texts, LongReD distills the hidden state of selected layers from the original model on short texts. Additionally, LongReD also introduces a short-to-long distillation, aligning the output distribution on short texts with that on long texts by leveraging skipped positional indices. Experiments on common text benchmarks demonstrate that LongReD effectively preserves the model's short-text performance while maintaining comparable or even better capacity to handle long texts than baselines. Our code is available at https://github.com/RUCAIBox/LongReD.