PAT: Pruning-Aware Tuning for Large Language Models

📄 arXiv: 2408.14721v2 📥 PDF

作者: Yijiang Liu, Huanrui Yang, Youxin Chen, Rongyu Zhang, Miao Wang, Yuan Du, Li Du

分类: cs.LG, cs.AI, cs.CL

发布日期: 2024-08-27 (更新: 2025-01-25)

备注: Accepted by AAAI 2025

🔗 代码/项目: GITHUB


💡 一句话要点

提出PAT:一种面向大语言模型的剪枝感知调优方法,提升效率并保持性能。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 剪枝 微调 模型压缩 结构化剪枝

📋 核心要点

  1. 大型语言模型计算开销大,后剪枝微调性能损失明显,难以兼顾效率与性能。
  2. PAT通过混合稀疏化模块(HSM)在微调过程中进行剪枝,保持模型性能的同时降低计算成本。
  3. 实验表明,PAT在Llama2-7b模型上实现了显著的加速和性能提升,优于LoRA等微调方法。

📝 摘要(中文)

大型语言模型(LLMs)在语言任务中表现出色,尤其是在预训练后进行监督微调时。然而,它们巨大的内存和计算需求阻碍了实际应用。结构化剪枝是一种解决方案,它减少了不太重要的权重维度。然而,传统的后置剪枝通常会导致显著的性能损失,并且由于容量减少,进一步微调的恢复效果有限。由于模型微调提炼了预训练模型中的通用和混乱知识,我们旨在将结构化剪枝与微调相结合,并提出剪枝感知调优(PAT)范式,以消除模型冗余,同时最大限度地保持模型性能。具体来说,我们在Attention和FFN组件之间插入创新的混合稀疏化模块(HSM),以相应地稀疏化上游和下游线性模块。HSM包括一个轻量级算子和一个全局共享的可训练掩码。轻量级算子保持与LoRA相当的训练开销,而可训练掩码统一了要稀疏化的通道,确保了结构化剪枝。此外,我们提出了身份损失,它解耦了HSM的变换和缩放属性,以增强训练的鲁棒性。大量的实验表明,PAT在性能和效率方面都表现出色。例如,我们剪枝率为25%的Llama2-7b模型实现了1.33倍的加速,同时在相似的训练成本下,其准确率比LoRA微调模型高出1.26%。

🔬 方法详解

问题定义:现有的大型语言模型(LLMs)在部署时面临着巨大的内存和计算资源需求。传统的后剪枝方法虽然可以减少模型大小,但往往会导致显著的性能下降,并且后续的微调难以完全恢复性能损失。因此,如何在保证模型性能的前提下,有效地降低LLMs的计算成本是一个关键问题。

核心思路:PAT的核心思路是在微调过程中同时进行剪枝,即“剪枝感知调优”。通过在微调过程中引入稀疏化模块,模型可以自适应地学习哪些权重是不重要的,并在训练过程中逐步剪除这些权重。这种方法避免了后剪枝带来的突然性能下降,并允许模型在剪枝的同时进行优化,从而更好地保持性能。

技术框架:PAT在Transformer模型的Attention和FFN模块之间插入混合稀疏化模块(HSM)。HSM包含一个轻量级算子和一个全局共享的可训练掩码。整个训练过程包括正常的微调过程,同时HSM中的可训练掩码也在学习,以确定哪些通道应该被剪枝。

关键创新:PAT的关键创新在于混合稀疏化模块(HSM)的设计和身份损失的引入。HSM通过轻量级算子和全局共享的可训练掩码,实现了结构化剪枝,保证了剪枝后的模型仍然具有良好的结构。身份损失则解耦了HSM的变换和缩放属性,增强了训练的鲁棒性,使得模型在剪枝过程中更加稳定。

关键设计:HSM包含一个轻量级线性变换和一个可训练的mask。这个mask是全局共享的,用于控制哪些通道被剪枝。Identity Loss被设计用来解耦HSM的变换和缩放属性,其具体形式未知(论文未明确给出公式)。训练过程中,模型使用标准的微调损失函数,同时加入Identity Loss来优化HSM。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

PAT在Llama2-7b模型上进行了实验,结果表明,在25%的剪枝率下,PAT实现了1.33倍的加速,同时其准确率比LoRA微调模型高出1.26%。这表明PAT在保持模型性能的同时,有效地降低了计算成本,优于传统的微调方法。

🎯 应用场景

PAT具有广泛的应用前景,可用于降低大型语言模型的部署成本,使其能够在资源受限的设备上运行。例如,可以将PAT应用于移动设备、嵌入式系统或边缘计算环境,从而实现高效的自然语言处理应用。此外,PAT还可以用于加速模型的推理速度,提高用户体验。

📄 摘要(原文)

Large language models (LLMs) excel in language tasks, especially with supervised fine-tuning after pre-training. However, their substantial memory and computational requirements hinder practical applications. Structural pruning, which reduces less significant weight dimensions, is one solution. Yet, traditional post-hoc pruning often leads to significant performance loss, with limited recovery from further fine-tuning due to reduced capacity. Since the model fine-tuning refines the general and chaotic knowledge in pre-trained models, we aim to incorporate structural pruning with the fine-tuning, and propose the Pruning-Aware Tuning (PAT) paradigm to eliminate model redundancy while preserving the model performance to the maximum extend. Specifically, we insert the innovative Hybrid Sparsification Modules (HSMs) between the Attention and FFN components to accordingly sparsify the upstream and downstream linear modules. The HSM comprises a lightweight operator and a globally shared trainable mask. The lightweight operator maintains a training overhead comparable to that of LoRA, while the trainable mask unifies the channels to be sparsified, ensuring structural pruning. Additionally, we propose the Identity Loss which decouples the transformation and scaling properties of the HSMs to enhance training robustness. Extensive experiments demonstrate that PAT excels in both performance and efficiency. For example, our Llama2-7b model with a 25\% pruning ratio achieves 1.33$\times$ speedup while outperforming the LoRA-finetuned model by up to 1.26\% in accuracy with a similar training cost. Code: https://github.com/kriskrisliu/PAT_Pruning-Aware-Tuning