Learning Dynamics in Continual Pre-Training for Large Language Models

📄 arXiv: 2505.07796v2 📥 PDF

作者: Xingjin Wang, Howe Tissue, Lu Wang, Linjing Li, Daniel Dajun Zeng

分类: cs.CL, cs.AI, cs.LG

发布日期: 2025-05-12 (更新: 2025-06-19)

备注: Accepted to ICML2025 (Oral)


💡 一句话要点

提出CPT缩放定律,预测大语言模型持续预训练过程中的性能演变。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 持续预训练 大语言模型 缩放定律 学习动态 领域自适应

📋 核心要点

  1. 持续预训练(CPT)是提升大语言模型在特定领域性能的有效方法,但其学习过程的动态变化尚不明确。
  2. 论文通过解耦分布偏移和学习率退火的影响,提出了CPT缩放定律,用于预测训练过程中的损失变化。
  3. 实验证明,该缩放定律在不同的CPT数据集和超参数设置下均有效,可用于定制训练策略。

📝 摘要(中文)

持续预训练(CPT)已成为将强大的基础模型应用于特定下游任务的一种流行且有效的方法。本文探讨了大型语言模型CPT过程中的学习动态。特别关注通用领域和下游领域性能在每个训练步骤中的演变,其中领域性能通过验证损失来衡量。观察到CPT损失曲线从根本上表征了从一条曲线到另一条隐藏曲线的过渡,并且可以通过解耦分布偏移和学习率退火的影响来描述。推导出结合这两个因素的CPT缩放定律,从而能够预测CPT中任何(持续)训练步骤和跨学习率调度(LRS)的损失。该公式全面理解了CPT中的几个关键因素,包括损失潜力、峰值学习率、训练步骤、重放率等。此外,该方法可以适应于定制针对不同CPT目标的训练超参数,例如平衡通用和领域特定性能。大量实验表明,该缩放定律适用于各种CPT数据集和训练超参数。

🔬 方法详解

问题定义:现有持续预训练方法缺乏对训练过程中学习动态的深入理解,难以预测模型在通用领域和特定领域性能的演变。如何有效地平衡通用能力和领域特定能力,并根据不同的训练目标定制超参数,是一个重要的挑战。

核心思路:论文的核心思路是将CPT过程中的损失曲线分解为分布偏移和学习率退火两个因素的影响,通过建立CPT缩放定律来预测损失变化。该缩放定律能够量化关键因素(如损失潜力、峰值学习率、训练步数、重放率等)对模型性能的影响。

技术框架:该研究主要通过理论分析和实验验证来探索CPT的学习动态。首先,通过观察CPT损失曲线,发现其表征了从一条曲线到另一条隐藏曲线的过渡。然后,通过解耦分布偏移和学习率退火的影响,推导出CPT缩放定律。最后,通过在各种CPT数据集和超参数设置下进行实验,验证该缩放定律的有效性。

关键创新:论文的关键创新在于提出了CPT缩放定律,该定律能够预测CPT过程中任何训练步骤和跨学习率调度的损失。与现有方法相比,该定律能够更全面地理解CPT中的关键因素,并为定制训练策略提供指导。

关键设计:CPT缩放定律的具体形式未知,但摘要中提到它结合了分布偏移和学习率退火两个因素。论文通过实验验证了该定律在不同数据集和超参数下的有效性,并展示了如何利用该定律来定制训练超参数,以平衡通用和领域特定性能。具体的参数设置、损失函数和网络结构等细节在摘要中未提及,属于未知信息。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

论文通过大量实验验证了CPT缩放定律的有效性,表明该定律能够准确预测CPT过程中损失的变化。实验结果表明,该方法可以用于定制训练超参数,以平衡通用和领域特定性能,从而提升模型在特定任务上的性能。具体的性能提升幅度未知。

🎯 应用场景

该研究成果可应用于各种需要持续预训练的大语言模型应用场景,例如金融、医疗、法律等领域。通过CPT缩放定律,可以更有效地进行模型训练,平衡通用能力和领域特定能力,并根据不同的应用需求定制训练策略,从而提升模型在特定任务上的性能。

📄 摘要(原文)

Continual Pre-Training (CPT) has become a popular and effective method to apply strong foundation models to specific downstream tasks. In this work, we explore the learning dynamics throughout the CPT process for large language models. We specifically focus on how general and downstream domain performance evolves at each training step, with domain performance measured via validation losses. We have observed that the CPT loss curve fundamentally characterizes the transition from one curve to another hidden curve, and could be described by decoupling the effects of distribution shift and learning rate annealing. We derive a CPT scaling law that combines the two factors, enabling the prediction of loss at any (continual) training steps and across learning rate schedules (LRS) in CPT. Our formulation presents a comprehensive understanding of several critical factors in CPT, including loss potential, peak learning rate, training steps, replay ratio, etc. Moreover, our approach can be adapted to customize training hyper-parameters to different CPT goals such as balancing general and domain-specific performance. Extensive experiments demonstrate that our scaling law holds across various CPT datasets and training hyper-parameters.