Direct Quantized Training of Language Models with Stochastic Rounding
作者: Kaiyan Zhao, Tsuguchika Tabaru, Kenichi Kobayashi, Takumi Honda, Masafumi Yamazaki, Yoshimasa Tsuruoka
分类: cs.LG, cs.CL
发布日期: 2024-12-06 (更新: 2025-10-10)
备注: Accepted to ACML 2025
💡 一句话要点
提出基于随机舍入的语言模型直接量化训练方法,降低训练时的内存占用。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 量化训练 低精度计算 大型语言模型 随机舍入 内存优化
📋 核心要点
- 现有量化LLM训练依赖高精度权重进行STE,内存占用大,是核心问题。
- 采用随机舍入直接更新量化低精度权重,无需STE,降低内存需求。
- 实验表明,三元权重训练可行,8比特权重性能媲美BitNet b1.58,且模型具有精度缩放鲁棒性。
📝 摘要(中文)
本文提出了一种语言模型的直接量化训练方法,旨在减少训练过程中的内存占用。与以往依赖Straight-Through Estimation(STE)的方法不同,该方法直接更新量化的低精度权重,从而避免了维护高精度权重副本的需求。具体而言,论文采用随机舍入技术来最小化低比特权重带来的信息损失。在不同规模的LLaMA结构模型上的实验结果表明:(1)即使权重被限制为三元值,仅使用低精度权重进行训练也是可行的;(2)将比特宽度扩展到8比特可以达到与BitNet b1.58相当的性能;(3)该模型对精度缩放和内存减少具有鲁棒性,从FP32迁移到低内存环境(BF16/FP8)时性能下降很小;(4)该模型还支持使用三元权重进行推理,展示了其部署的灵活性。
🔬 方法详解
问题定义:现有的大型语言模型量化训练方法,如BitNet,虽然在部署时可以通过二元或三元权重显著降低内存占用,但在训练阶段仍然需要维护高精度(未量化)的权重,以便使用Straight-Through Estimation(STE)。这种需求导致训练过程中的内存占用仍然很高,限制了更大模型的训练和部署。因此,需要一种能够在训练过程中也降低内存占用的量化训练方法。
核心思路:本文的核心思路是直接更新量化的低精度权重,而无需依赖STE进行反向传播。通过避免维护高精度权重,可以显著降低训练过程中的内存占用。为了减轻低精度量化带来的信息损失,论文采用了随机舍入技术。
技术框架:该方法主要包含以下几个阶段:首先,将模型权重初始化为低精度量化值。然后,在每次迭代中,使用随机舍入技术对梯度进行量化。接着,直接使用量化后的梯度更新量化的权重。整个训练过程都在低精度下进行,无需维护高精度权重。
关键创新:该方法最重要的创新点在于避免了使用STE进行反向传播。STE虽然是一种常用的量化训练方法,但它需要在训练过程中维护高精度权重,增加了内存占用。通过直接更新量化的权重,该方法可以显著降低内存占用,并允许在资源受限的环境中训练更大的模型。与现有方法的本质区别在于,现有方法依赖高精度权重进行梯度计算,而该方法完全在低精度下进行。
关键设计:论文中一个关键的设计是使用随机舍入技术。随机舍入通过引入随机性,可以减少量化误差带来的偏差,从而提高模型的性能。具体的实现方式是,对于一个需要量化的值x,首先计算其量化后的值q(x)。然后,以一定的概率将x舍入到q(x)或q(x)+1。这个概率取决于x与q(x)之间的距离。此外,论文还探索了不同的比特宽度(如三元、8比特)对模型性能的影响。
🖼️ 关键图片
📊 实验亮点
实验结果表明,即使使用三元权重进行训练也是可行的,并且将比特宽度扩展到8比特可以达到与BitNet b1.58相当的性能。更重要的是,该模型对精度缩放和内存减少具有鲁棒性,从FP32迁移到低内存环境(BF16/FP8)时性能下降很小,同时支持三元权重推理。
🎯 应用场景
该研究成果可应用于资源受限的边缘设备或移动设备上部署大型语言模型。通过降低训练和推理过程中的内存占用,使得在这些设备上运行复杂的AI模型成为可能。此外,该方法还可以加速模型训练,并降低训练成本,从而推动AI技术的普及。
📄 摘要(原文)
Although recent quantized Large Language Models (LLMs), such as BitNet, have paved the way for significant reduction in memory usage during deployment with binary or ternary weights, training these models still demands substantial memory footprints. This is partly because high-precision (i.e., unquantized) weights required for straight-through estimation must be maintained throughout the whole training process. To address this, we explore directly updating the quantized low-precision weights without relying on straight-through estimation during backpropagation, aiming to save memory usage during training. Specifically, we employ a stochastic rounding technique to minimize the information loss caused by the use of low-bit weights throughout training. Experimental results on our LLaMA-structured models of various sizes indicate that (1) training with only low-precision weights is feasible even when they are constrained to ternary values; (2) extending the bit width to 8 bits achieves performance on par with BitNet b1.58; (3) our models remain robust to precision scaling and memory reduction, showing minimal performance degradation when moving from FP32 to lower-memory environments (BF16/FP8); and (4) our models also support inference using ternary weights, showcasing their flexibility in deployment.