SPAM: Spike-Aware Adam with Momentum Reset for Stable LLM Training

📄 arXiv: 2501.06842v2 📥 PDF

作者: Tianjin Huang, Ziquan Zhu, Gaojie Jin, Lu Liu, Zhangyang Wang, Shiwei Liu

分类: cs.LG, cs.AI, cs.CL

发布日期: 2025-01-12 (更新: 2025-02-28)

🔗 代码/项目: GITHUB


💡 一句话要点

提出SPAM优化器,通过动量重置和梯度裁剪解决LLM训练中的梯度爆炸问题,提升训练稳定性和资源效率。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大型语言模型 优化器 梯度爆炸 动量重置 梯度裁剪 训练稳定性 内存效率

📋 核心要点

  1. LLM训练面临梯度爆炸问题,导致训练不稳定,需要频繁进行checkpoint恢复和实验重启,效率低下。
  2. SPAM优化器通过动量重置和spike-aware梯度裁剪来抑制梯度尖峰,从而稳定训练过程。
  3. 实验表明,SPAM在LLM预训练、微调、强化学习和时间序列预测等任务上优于Adam及其变体,并具有更高的内存效率。

📝 摘要(中文)

大型语言模型(LLM)在各种任务中表现出卓越的性能,但其训练仍然是资源密集型的,并且容易受到训练不稳定等关键挑战的影响。这种不稳定的一个主要来源是梯度和损失尖峰,它们会扰乱学习过程,通常导致代价高昂的干预措施,如检查点恢复和实验重启,从而进一步加剧效率低下。本文对LLM训练中观察到的梯度尖峰进行了全面的研究,揭示了它们在多种架构和数据集中的普遍性。我们的分析表明,这些尖峰可能比典型梯度大1000倍,从而大大降低模型性能。为了解决这个问题,我们提出了一种新的优化器Spike-Aware Adam with Momentum Reset (SPAM),旨在通过动量重置和spike-aware梯度裁剪来抵消梯度尖峰。大量的实验,包括预训练和微调,表明SPAM在各种任务中始终优于Adam及其变体,包括(1)从60M到1B的LLM预训练,(2)4-bit LLM预训练,(3)强化学习,以及(4)时间序列预测。此外,SPAM通过启用稀疏动量来促进内存高效的训练,其中只维护和更新动量项的子集。当在内存约束下运行时,SPAM优于最先进的内存高效优化器,如GaLore和Adam-Mini。我们的工作强调了减轻LLM训练中梯度尖峰的重要性,并引入了一种有效的优化策略,可以在大规模上提高训练稳定性和资源效率。

🔬 方法详解

问题定义:论文旨在解决大型语言模型(LLM)训练过程中出现的梯度爆炸问题,这种现象会导致训练不稳定,模型性能下降,并增加训练成本。现有的优化器,如Adam,在面对梯度尖峰时表现不佳,需要人工干预,例如checkpoint恢复和实验重启。

核心思路:论文的核心思路是通过一种spike-aware的优化器来减轻梯度尖峰的影响。具体来说,SPAM优化器结合了动量重置和梯度裁剪两种机制。动量重置可以在检测到梯度尖峰时重置动量项,从而避免梯度尖峰对后续更新产生过大的影响。梯度裁剪则限制了梯度的最大值,防止梯度过大。

技术框架:SPAM优化器基于Adam优化器,并在其基础上添加了动量重置和spike-aware梯度裁剪两个模块。整体流程如下:1. 计算梯度;2. 检测梯度尖峰;3. 如果检测到梯度尖峰,则重置动量项;4. 对梯度进行裁剪;5. 使用裁剪后的梯度更新模型参数。

关键创新:SPAM优化器的关键创新在于其spike-aware的设计。传统的梯度裁剪方法通常使用固定的阈值,而SPAM优化器可以根据梯度的统计信息动态调整裁剪阈值。此外,SPAM优化器还支持稀疏动量,即只维护和更新一部分动量项,从而降低内存消耗。

关键设计:SPAM优化器的关键设计包括:1. 梯度尖峰检测:使用梯度值的移动平均和标准差来检测梯度尖峰。2. 动量重置:当检测到梯度尖峰时,将动量项重置为0。3. Spike-aware梯度裁剪:根据梯度值的统计信息动态调整裁剪阈值。4. 稀疏动量:只维护和更新一部分动量项,从而降低内存消耗。具体实现细节可以参考论文提供的代码。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,SPAM优化器在LLM预训练(60M到1B参数)、4-bit LLM预训练、强化学习和时间序列预测等任务上均优于Adam及其变体。在内存受限的情况下,SPAM优化器也优于GaLore和Adam-Mini等内存高效优化器。梯度尖峰可以比典型梯度大1000倍,SPAM能有效缓解该问题。

🎯 应用场景

SPAM优化器可广泛应用于各种需要训练大型语言模型的场景,例如自然语言处理、机器翻译、文本生成等。它能够提高训练的稳定性和效率,降低训练成本,并支持在资源受限的环境下进行训练。该研究成果对于推动大型语言模型的发展和应用具有重要意义。

📄 摘要(原文)

Large Language Models (LLMs) have demonstrated exceptional performance across diverse tasks, yet their training remains highly resource-intensive and susceptible to critical challenges such as training instability. A predominant source of this instability stems from gradient and loss spikes, which disrupt the learning process, often leading to costly interventions like checkpoint recovery and experiment restarts, further amplifying inefficiencies. This paper presents a comprehensive investigation into gradient spikes observed during LLM training, revealing their prevalence across multiple architectures and datasets. Our analysis shows that these spikes can be up to $1000\times$ larger than typical gradients, substantially deteriorating model performance. To address this issue, we propose Spike-Aware Adam with Momentum Reset SPAM, a novel optimizer designed to counteract gradient spikes through momentum reset and spike-aware gradient clipping. Extensive experiments, including both pre-training and fine-tuning, demonstrate that SPAM consistently surpasses Adam and its variants across various tasks, including (1) LLM pre-training from 60M to 1B, (2) 4-bit LLM pre-training,(3) reinforcement learning, and (4) Time Series Forecasting. Additionally, SPAM facilitates memory-efficient training by enabling sparse momentum, where only a subset of momentum terms are maintained and updated. When operating under memory constraints, SPAM outperforms state-of-the-art memory-efficient optimizers such as GaLore and Adam-Mini. Our work underscores the importance of mitigating gradient spikes in LLM training and introduces an effective optimization strategy that enhances both training stability and resource efficiency at scale. Code is available at https://github.com/TianjinYellow/SPAM-Optimizer.git