Train Small, Infer Large: Memory-Efficient LoRA Training for Large Language Models
作者: Jun Zhang, Jue Wang, Huan Li, Lidan Shou, Ke Chen, Yang You, Guiming Xie, Xuejian Gong, Kunlong Zhou
分类: cs.LG, cs.AI, cs.CL
发布日期: 2025-02-19 (更新: 2025-03-15)
备注: Accepted at ICLR 2025
🔗 代码/项目: GITHUB
💡 一句话要点
LoRAM:通过训练小模型、推理大模型,实现大语言模型的高效LoRA训练
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大语言模型 低秩适应 LoRA 模型剪枝 内存优化 参数高效微调 持续预训练 量化
📋 核心要点
- LoRA训练LLM时,原始模型参数占用大量内存,限制了在资源有限的设备上进行训练。
- LoRAM通过在剪枝后的小模型上训练LoRA适配器,然后在完整大模型上推理,显著降低了训练时的内存需求。
- 实验表明,LoRAM在多种剪枝策略和下游任务中表现出色,能以更低的硬件成本达到甚至超越现有LoRA方法的性能。
📝 摘要(中文)
大型语言模型(LLMs)在自然语言处理领域取得了显著进展,具有卓越的任务泛化能力。低秩适应(LoRA)提供了一种经济高效的微调解决方案,它冻结原始模型参数,仅训练轻量级的低秩适配器矩阵。然而,LoRA的内存占用主要由原始模型参数决定。为了缓解这个问题,我们提出了LoRAM,这是一种内存高效的LoRA训练方案,其基础是过度参数化的LLM中许多神经元的训练效用较低,但对于推理至关重要。LoRAM提出了一种独特的方案:它在修剪后的(小型)模型上进行训练,以获得修剪后的低秩矩阵,然后恢复这些矩阵并与原始(大型)模型一起用于推理。此外,由模型发布者预先执行的最低成本的持续预训练,可以对齐修剪模型和原始模型之间的知识差异。我们广泛的实验证明了LoRAM在各种剪枝策略和下游任务中的有效性。对于具有700亿参数的模型,LoRAM能够在仅具有20G HBM的GPU上进行训练,从而取代了用于LoRA训练的A100-80G GPU和用于完全微调的15个GPU。具体而言,通过结构化剪枝结合4位量化实现的QLoRAM,对于LLaMA-3.1-70B(LLaMA-2-70B),将低秩矩阵训练中占据内存使用主导地位的参数存储成本降低了15.81倍(16.95倍),同时实现了优于原始LLaMA-3.1-70B(LLaMA-2-70B)和LoRA训练的LLaMA-3.1-8B(LLaMA-2-13B)的显著性能提升。
🔬 方法详解
问题定义:现有LoRA方法在微调大型语言模型时,需要加载完整的原始模型参数,这导致巨大的内存占用,使得在资源受限的设备上(例如,只有20G HBM的GPU)训练大型模型变得不可行。即使是参数高效的LoRA,其内存瓶颈仍然在于原始模型的存储。
核心思路:LoRAM的核心思想是利用大语言模型中的冗余性。许多神经元在训练过程中贡献较小,但对于保持模型的推理能力至关重要。因此,LoRAM首先对原始模型进行剪枝,得到一个更小的模型,然后在该小模型上训练LoRA适配器。训练完成后,将适配器恢复到原始大模型中进行推理。这样,训练过程的内存占用大大降低,而推理性能得以保持。
技术框架:LoRAM的整体框架包含以下几个主要阶段: 1. 模型剪枝:使用某种剪枝策略(例如,结构化剪枝)对原始大语言模型进行剪枝,得到一个参数量更小的模型。 2. 持续预训练 (可选):为了弥补剪枝带来的知识差异,可以对剪枝后的模型进行低成本的持续预训练。 3. LoRA训练:在剪枝后的模型上训练LoRA适配器。由于模型较小,训练所需的内存也大大降低。 4. 适配器恢复:将训练好的LoRA适配器恢复到原始大模型中。 5. 推理:使用原始大模型和LoRA适配器进行推理。
关键创新:LoRAM的关键创新在于将模型剪枝和LoRA训练相结合,实现了在小模型上训练、在大模型上推理的范式。这种方法打破了LoRA训练必须加载完整原始模型的限制,显著降低了内存需求。此外,通过可选的持续预训练步骤,可以进一步提升模型的性能。
关键设计: * 剪枝策略:论文探索了不同的剪枝策略,包括结构化剪枝。结构化剪枝更适合硬件加速,可以进一步提高效率。 * 持续预训练:持续预训练的成本需要控制,以保证整体训练效率。论文可能探索了不同的预训练数据和训练策略。 * 量化:为了进一步降低内存占用,论文采用了4-bit量化技术,即QLoRAM,对模型参数进行量化。
🖼️ 关键图片
📊 实验亮点
LoRAM在700亿参数的LLaMA模型上进行了实验,结果表明,LoRAM能够在仅有20G HBM的GPU上进行训练,取代了原本需要A100-80G GPU的LoRA训练和15个GPU的完整微调。QLoRAM结合结构化剪枝和4位量化,将LLaMA-3.1-70B(LLaMA-2-70B)的参数存储成本降低了15.81倍(16.95倍),同时取得了优于原始LLaMA-3.1-70B(LLaMA-2-70B)和LoRA训练的LLaMA-3.1-8B(LLaMA-2-13B)的性能。
🎯 应用场景
LoRAM技术可广泛应用于各种需要微调大型语言模型的场景,尤其是在计算资源有限的环境中。例如,研究人员和开发者可以使用LoRAM在消费级GPU上微调大型模型,从而降低了AI开发的门槛。此外,LoRAM还可以应用于边缘计算设备,使得这些设备能够运行更强大的AI模型,从而提升智能化水平。
📄 摘要(原文)
Large Language Models (LLMs) have significantly advanced natural language processing with exceptional task generalization capabilities. Low-Rank Adaption (LoRA) offers a cost-effective fine-tuning solution, freezing the original model parameters and training only lightweight, low-rank adapter matrices. However, the memory footprint of LoRA is largely dominated by the original model parameters. To mitigate this, we propose LoRAM, a memory-efficient LoRA training scheme founded on the intuition that many neurons in over-parameterized LLMs have low training utility but are essential for inference. LoRAM presents a unique twist: it trains on a pruned (small) model to obtain pruned low-rank matrices, which are then recovered and utilized with the original (large) model for inference. Additionally, minimal-cost continual pre-training, performed by the model publishers in advance, aligns the knowledge discrepancy between pruned and original models. Our extensive experiments demonstrate the efficacy of LoRAM across various pruning strategies and downstream tasks. For a model with 70 billion parameters, LoRAM enables training on a GPU with only 20G HBM, replacing an A100-80G GPU for LoRA training and 15 GPUs for full fine-tuning. Specifically, QLoRAM implemented by structured pruning combined with 4-bit quantization, for LLaMA-3.1-70B (LLaMA-2-70B), reduces the parameter storage cost that dominates the memory usage in low-rank matrix training by 15.81$\times$ (16.95$\times$), while achieving dominant performance gains over both the original LLaMA-3.1-70B (LLaMA-2-70B) and LoRA-trained LLaMA-3.1-8B (LLaMA-2-13B). Code is available at https://github.com/junzhang-zj/LoRAM.