An Extra RMSNorm is All You Need for Fine Tuning to 1.58 Bits
作者: Cody Steinmetz, Gavin Childress, Aaron Herbst, Gavin Jones, Jasdeep Singh, Eli Vang, Keagan Weinstock
分类: cs.LG, cs.AI, cs.CL
发布日期: 2025-05-12
💡 一句话要点
仅需额外RMSNorm即可微调至1.58比特量化精度
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 三元量化 低比特量化 大型语言模型 RMSNorm 模型微调 量化感知训练 Transformer 模型压缩
📋 核心要点
- 现有LLM量化方法,如后训练量化和量化感知训练,在精度和训练成本上存在权衡。
- 该论文提出在每个线性投影前添加RMSNorm,并采用渐进式逐层量化策略,实现稳定的三元量化微调。
- 实验表明,该方法在标准语言建模基准上,性能匹配或超越了更复杂的知识蒸馏方法。
📝 摘要(中文)
大型语言模型(LLMs)改变了自然语言处理领域,但其规模使其在现实世界中的部署成本高昂。后训练量化减少了内存和计算开销,但通常会降低精度,而量化感知训练可以通过额外的训练来恢复性能。将量化推向三元(2比特)状态可以节省更多资源,但众所周知其稳定性较差。基于最近的研究表明,使用直通估计的无偏置、RMS归一化Transformer可以达到1.58比特的精度,我们证明了简单地在每个线性投影之前插入RMS归一化,并应用渐进的、逐层的量化策略,可以稳定地将全精度检查点微调为三元LLM。我们的方法在标准语言建模基准测试中匹配或超过了更复杂的知识蒸馏流程,而没有增加模型复杂性。这些结果表明,仅通过仔细的归一化就可以缩小三元LLM和全精度LLM之间的大部分精度差距,从而使超低比特推理成为可能。
🔬 方法详解
问题定义:论文旨在解决将大型语言模型(LLMs)量化到极低比特(如三元,即2比特)时,精度显著下降且训练不稳定的问题。现有的量化方法,例如后训练量化,虽然可以降低计算和存储成本,但会严重损害模型性能。量化感知训练虽然可以恢复性能,但需要额外的训练开销,并且在极低比特量化时仍然面临稳定性挑战。
核心思路:论文的核心思路是,通过在Transformer模型的每个线性投影层之前插入RMSNorm(Root Mean Square Normalization)层,并结合渐进的、逐层的量化策略,来稳定地进行三元量化微调。RMSNorm能够有效地控制模型内部激活值的尺度,从而提高量化过程的稳定性。渐进式量化则避免了模型参数一次性剧烈变化,进一步提升了微调的稳定性。
技术框架:该方法基于标准的Transformer架构,主要修改在于:1) 在每个线性投影层(例如,全连接层、卷积层)之前添加一个RMSNorm层;2) 采用渐进的、逐层的量化策略。具体来说,首先对模型的某些层进行量化,然后逐步增加量化的层数,直到所有层都被量化为三元。整个流程从一个预训练好的全精度LLM开始,然后使用提出的方法进行微调。
关键创新:该方法最重要的创新点在于发现简单的RMSNorm的添加,结合渐进式量化,就能显著提升极低比特量化的稳定性和性能。与以往需要复杂的知识蒸馏或其他技巧的方法相比,该方法更加简洁有效,且不需要额外的模型复杂度。
关键设计:关键设计包括:1) RMSNorm的具体实现,通常采用标准RMSNorm公式,对输入进行归一化;2) 渐进式量化的具体schedule,例如,可以按照层数或模块的重要性进行排序,逐步增加量化层数;3) 量化函数的选择,可以使用straight-through estimator (STE) 来近似梯度,从而使得量化操作可以进行反向传播。
📊 实验亮点
实验结果表明,该方法在标准语言建模基准测试中,能够以1.58比特的精度匹配或超过更复杂的知识蒸馏流程,而无需增加模型复杂度。这表明,通过仔细的归一化和量化策略,可以显著缩小三元LLM和全精度LLM之间的性能差距,使得超低比特推理成为可能。
🎯 应用场景
该研究成果可广泛应用于资源受限的场景,例如移动设备、边缘计算和嵌入式系统。通过将大型语言模型量化到极低比特,可以在这些平台上部署更强大的AI应用,例如智能助手、机器翻译和文本生成,同时降低功耗和存储成本。此外,该方法还可以加速LLM的推理速度,提升用户体验。
📄 摘要(原文)
Large language models (LLMs) have transformed natural-language processing, yet their scale makes real-world deployment costly. Post-training quantization reduces memory and computation but often degrades accuracy, while quantization-aware training can recover performance at the cost of extra training. Pushing quantization to the ternary (2-bit) regime yields even larger savings but is notoriously unstable. Building on recent work showing that a bias-free, RMS-normalized Transformer with straight-through estimation can reach 1.58-bit precision, we demonstrate that simply inserting RMS normalization before every linear projection and applying a gradual, layer-wise quantization schedule stably fine-tunes full-precision checkpoints into ternary LLMs. Our approach matches or surpasses more elaborate knowledge-distillation pipelines on standard language-modeling benchmarks without adding model complexity. These results indicate that careful normalization alone can close much of the accuracy gap between ternary and full-precision LLMs, making ultra-low-bit inference practical.