A Multi-Power Law for Loss Curve Prediction Across Learning Rate Schedules
作者: Kairong Luo, Haodong Wen, Shengding Hu, Zhenbo Sun, Zhiyuan Liu, Maosong Sun, Kaifeng Lyu, Wenguang Chen
分类: cs.LG, cs.AI, cs.CL, stat.ML
发布日期: 2025-03-17
💡 一句话要点
提出一种多幂律模型,用于预测不同学习率策略下大语言模型的损失曲线。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大语言模型 学习率策略 损失曲线预测 多幂律模型 预训练 超参数优化 模型训练
📋 核心要点
- 现有方法难以准确预测不同学习率策略下大模型的损失曲线,阻碍了高效的模型训练。
- 提出一种多幂律模型,结合学习率总和的幂律和学习率衰减的额外幂律,更准确地描述损失曲线。
- 实验表明,该模型能准确预测未见学习率策略的损失曲线,并自动发现优于余弦策略的学习率策略。
📝 摘要(中文)
训练大型模型耗费资源和时间,因此理解模型性能与超参数之间的定量关系至关重要。本文提出了一种经验定律,描述了大型语言模型在不同学习率策略(如常数、余弦和阶梯衰减策略)下的预训练损失演变。该定律采用多幂律形式,结合了基于学习率总和的幂律,以及额外的幂律来解释学习率衰减引起的损失减少效应。我们在各种模型大小和架构上广泛验证了该定律,并证明在少量学习率策略上进行拟合后,该定律可以准确预测不同形状和范围的未见策略的损失曲线。此外,通过最小化不同学习率策略下的预测最终预训练损失,我们能够找到一种优于广泛使用的余弦学习率策略的策略。有趣的是,这种自动发现的策略与最近提出的 Warmup-Stable-Decay (WSD) 策略 (Hu et al, 2024) 有些相似,但实现了略低的最终损失。我们相信这些结果可以为理解预训练的动态和设计学习率策略以提高效率提供有价值的见解。
🔬 方法详解
问题定义:现有方法难以准确预测不同学习率策略下,特别是学习率衰减策略对大语言模型预训练损失曲线的影响。这使得选择和优化学习率策略变得困难,导致训练效率低下,浪费计算资源。现有的方法通常依赖于经验或简单的启发式规则,缺乏对损失曲线演变的定量理解。
核心思路:论文的核心思路是将损失曲线的演变建模为一个多幂律函数,该函数不仅考虑了学习率的总和对损失的影响,还考虑了学习率衰减带来的额外损失减少效应。通过拟合少量不同学习率策略下的损失曲线,可以学习到该多幂律函数的参数,从而预测其他未见学习率策略下的损失曲线。这种方法能够更准确地捕捉学习率策略对损失的影响,从而指导学习率策略的选择和优化。
技术框架:该方法主要包含以下几个阶段: 1. 数据收集:在不同的学习率策略下训练大语言模型,并记录预训练过程中的损失曲线。 2. 模型拟合:使用收集到的损失曲线数据,拟合提出的多幂律模型,得到模型的参数。 3. 损失曲线预测:使用拟合好的模型,预测未见学习率策略下的损失曲线。 4. 学习率策略优化:通过最小化预测的最终预训练损失,自动搜索最优的学习率策略。
关键创新:该论文最重要的技术创新点在于提出了多幂律模型,该模型能够同时考虑学习率总和和学习率衰减对损失曲线的影响。与现有方法相比,该模型能够更准确地描述损失曲线的演变,从而实现更准确的损失曲线预测和更有效的学习率策略优化。现有方法通常只考虑学习率总和的影响,忽略了学习率衰减带来的额外损失减少效应。
关键设计:多幂律模型的核心公式包含两部分:一部分是基于学习率总和的幂律,另一部分是基于学习率衰减的额外幂律。学习率总和的幂律描述了学习率总和与损失之间的关系,而学习率衰减的额外幂律则描述了学习率衰减带来的额外损失减少效应。模型的参数通过最小化预测损失与实际损失之间的差异来学习。论文还探索了不同的学习率策略,包括常数、余弦和阶梯衰减策略,并验证了该模型在不同模型大小和架构上的有效性。
🖼️ 关键图片
📊 实验亮点
实验结果表明,该多幂律模型能够准确预测不同形状和范围的未见学习率策略的损失曲线。通过最小化预测的最终预训练损失,自动发现的学习率策略优于广泛使用的余弦学习率策略,并与Warmup-Stable-Decay (WSD) 策略相似,但实现了略低的最终损失。这些结果验证了该模型的有效性和实用性。
🎯 应用场景
该研究成果可应用于大语言模型的预训练,帮助研究人员和工程师更高效地选择和优化学习率策略,从而节省计算资源和时间。此外,该方法还可以推广到其他深度学习模型的训练中,提高模型训练的效率和性能。该研究对于推动人工智能领域的发展具有重要的实际价值和未来影响。
📄 摘要(原文)
Training large models is both resource-intensive and time-consuming, making it crucial to understand the quantitative relationship between model performance and hyperparameters. In this paper, we present an empirical law that describes how the pretraining loss of large language models evolves under different learning rate schedules, such as constant, cosine, and step decay schedules. Our proposed law takes a multi-power form, combining a power law based on the sum of learning rates and additional power laws to account for a loss reduction effect induced by learning rate decay. We extensively validate this law on various model sizes and architectures, and demonstrate that after fitting on a few learning rate schedules, the law accurately predicts the loss curves for unseen schedules of different shapes and horizons. Moreover, by minimizing the predicted final pretraining loss across learning rate schedules, we are able to find a schedule that outperforms the widely used cosine learning rate schedule. Interestingly, this automatically discovered schedule bears some resemblance to the recently proposed Warmup-Stable-Decay (WSD) schedule (Hu et al, 2024) but achieves a slightly lower final loss. We believe these results could offer valuable insights for understanding the dynamics of pretraining and designing learning rate schedules to improve efficiency.