Towards Fully FP8 GEMM LLM Training at Scale

📄 arXiv: 2505.20524v2 📥 PDF

作者: Alejandro Hernández-Cano, Dhia Garbaya, Imanol Schlag, Martin Jaggi

分类: cs.LG

发布日期: 2025-05-26 (更新: 2025-10-24)

备注: 19 pages including appendix


💡 一句话要点

提出全FP8 GEMM LLM训练架构,提升大规模训练吞吐并保持精度

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: FP8训练 低精度计算 大型语言模型 GEMM优化 Transformer架构

📋 核心要点

  1. 现有FP8 LLM训练方法在保证稳定性和吞吐量之间存在trade-off,或精度不足,或效率不高。
  2. 该论文提出一种新的LLM架构,能够在transformer块的所有GEMM计算中使用FP8,提升训练效率。
  3. 实验表明,该方法在提升吞吐量的同时,能够匹配标准BF16训练的下游任务性能。

📝 摘要(中文)

尽管FP8数据格式在大型语言模型(LLM)预训练中具有显著潜力,但由于难以维持大规模训练的稳定性,其应用受到限制。现有方法通常依赖于次优的细粒度FP8内核,或者在敏感组件(如注意力投影)中回退到更高精度的矩阵乘法(GEMM),从而牺牲了潜在的吞吐量增益。我们提出了一种新型LLM架构,首次支持在transformer块内的所有GEMM中进行FP8计算,包括前向和后向传播。这实现了前所未有的吞吐量提升,尤其是在大规模训练中,同时匹配了标准BF16训练的下游性能。我们的架构设计减少了大型异常激活,促进了稳定的长期FP8训练。此外,我们还确定了关键指标来监控低精度训练,并预测潜在的未来发散。

🔬 方法详解

问题定义:现有FP8训练方法难以在保证稳定性的前提下,充分利用FP8的吞吐量优势。部分方法依赖于细粒度FP8内核,效率较低;另一些方法在关键模块回退到更高精度,影响整体性能。因此,如何在LLM训练中实现全FP8 GEMM计算,同时保持训练稳定性和模型精度,是一个亟待解决的问题。

核心思路:该论文的核心思路是通过架构设计来减少大型异常激活,从而促进稳定的FP8训练。通过控制激活值的范围,避免在低精度计算中出现溢出或下溢,从而保证训练过程的稳定性和精度。同时,监控关键指标,预测潜在的训练发散,以便及时调整训练策略。

技术框架:该论文提出了一种新型LLM架构,该架构在transformer块内的所有GEMM计算中都使用FP8数据格式,包括前向和后向传播。整体流程与标准的transformer架构类似,但针对FP8计算进行了优化。具体来说,可能包括对激活函数的调整、对权重初始化策略的改进,以及对梯度缩放技术的应用。

关键创新:该论文的关键创新在于实现了全FP8 GEMM LLM训练,即在transformer块的所有GEMM计算中都使用FP8数据格式。这是首次在LLM训练中实现如此彻底的低精度计算,从而最大化了FP8的吞吐量优势。此外,通过架构设计减少异常激活,保证了训练的稳定性。

关键设计:论文中可能包含以下关键设计:1) 激活值范围控制机制,例如使用特定的激活函数或对激活值进行裁剪;2) 权重初始化策略,例如使用特定的初始化分布或对权重进行缩放;3) 梯度缩放技术,例如动态梯度缩放,以避免梯度消失或爆炸;4) 关键指标监控,例如激活值的最大值、最小值、均值和方差,以及梯度值的统计信息。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

该论文提出的方法实现了前所未有的吞吐量提升,尤其是在大规模训练中。实验结果表明,该方法在保证下游任务性能与标准BF16训练相当的同时,显著提高了训练速度。具体性能数据(例如,吞吐量提升百分比、训练时间缩短比例)未知,但摘要强调了其显著的优势。

🎯 应用场景

该研究成果可广泛应用于大规模语言模型的预训练和微调,尤其是在计算资源受限的场景下,例如边缘设备或移动平台。通过降低计算精度,可以显著减少内存占用和计算时间,从而加速模型开发和部署,并降低成本。此外,该方法也有助于推动低精度计算在其他深度学习领域的应用。

📄 摘要(原文)

Despite the significant potential of FP8 data formats for large language model (LLM) pre-training, their adoption has been limited due to challenges in maintaining stability at scale. Existing approaches often rely on suboptimal fine-grained FP8 kernels or fall back to higher-precision matrix multiplications (GEMMs) in sensitive components, such as attention projections, compromising potential throughput gains. We introduce a new class of LLM architectures that, for the first time, support FP8 computation for all GEMMs within transformer blocks during both forward and backward passes. This enables unprecedented throughput gains, particularly at scale, while matching the downstream performance of standard BF16 training. Our architecture design reduces large outlier activations, promoting stable long-term FP8 training. In addition, we identify key metrics to monitor low-precision training and predict potential future divergences.