GRASPrune: Global Gating for Budgeted Structured Pruning of Large Language Models

📄 arXiv: 2604.19398v1 📥 PDF

作者: Ziyang Wang, Jiangfeng Xiao, Chuan Xiao, Ruoxiang Li, Rui Mao, Jianbin Qin

分类: cs.AI

发布日期: 2026-04-21

备注: Accepted to ACL 2026 Main Conference


💡 一句话要点

GRASPrune:面向大语言模型预算约束的全局门控结构化剪枝

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 结构化剪枝 全局门控 模型压缩 推理加速

📋 核心要点

  1. 大语言模型服务成本高昂,参数量、注意力计算和KV缓存带来巨大的内存和延迟开销,现有剪枝方法难以有效平衡模型大小和性能。
  2. GRASPrune通过学习轻量级门控分数,并使用投影直通估计器,在训练的每一步强制执行满足预算约束的硬掩码,实现高效剪枝。
  3. 实验表明,GRASPrune在LLaMA-2-7B上移除50%参数后,在WikiText-2上达到12.18的困惑度,并在多个基准测试上保持了竞争力的零样本准确率。

📝 摘要(中文)

本文提出GRASPrune,一个在预训练后应用的大语言模型结构化剪枝框架,它在单一全局预算下联合剪枝FFN通道和KV头组。GRASPrune学习轻量级的门控分数,并使用投影直通估计器来强制执行硬掩码,在每一步都满足预算约束,同时保持骨干权重冻结,而不是学习无约束的重要性分数并在训练后应用预算。在掩码固定后,我们校准保留单元上的缩放因子,以减轻剪枝引起的尺度不匹配,并将这些因子折叠到剪枝后的权重中,从而获得一个更小的稠密检查点,在推理时无需额外的参数。在LLaMA-2-7B上,GRASPrune移除了50%的参数,并在WikiText-2上实现了12.18的困惑度,同时在五个基准测试上保持了有竞争力的平均零样本准确率,仅使用单个NVIDIA A100 80GB GPU上的512个未标记校准序列,进行了四个epoch的训练,而无需任何完整的模型微调。

🔬 方法详解

问题定义:大语言模型(LLMs)的部署成本很高,主要由于其庞大的参数量、注意力计算以及KV缓存带来的巨大内存和延迟开销。现有的剪枝方法通常先学习无约束的重要性得分,然后在训练后应用预算,这可能导致次优的剪枝结果,并且难以在训练过程中精确控制剪枝比例。

核心思路:GRASPrune的核心思路是在训练过程中,通过全局门控机制,强制执行预算约束下的结构化剪枝。具体来说,它学习轻量级的门控分数,并使用投影直通估计器(Projected Straight-Through Estimator)来生成硬掩码,确保在每一步训练中都满足预设的剪枝比例。这种方法能够在训练早期就引导模型学习对剪枝友好的表示,从而提高剪枝后的模型性能。

技术框架:GRASPrune框架主要包含以下几个阶段:1) 门控分数学习:为FFN通道和KV头组学习轻量级的门控分数。2) 硬掩码生成:使用投影直通估计器,根据门控分数和预算约束生成硬掩码。3) 权重冻结与训练:在训练过程中,保持骨干权重冻结,只更新门控分数。4) 缩放因子校准:在掩码固定后,校准保留单元上的缩放因子,以减轻剪枝引起的尺度不匹配。5) 权重融合:将缩放因子折叠到剪枝后的权重中,得到一个更小的稠密检查点。

关键创新:GRASPrune的关键创新在于其全局门控机制和投影直通估计器的使用。全局门控机制允许在整个模型范围内进行统一的剪枝决策,从而更好地平衡不同层和模块的重要性。投影直通估计器则能够在训练过程中强制执行预算约束,避免了传统方法中先训练再剪枝的次优性。

关键设计:GRASPrune的关键设计包括:1) 门控网络结构:使用轻量级的神经网络来学习门控分数,以减少额外的计算开销。2) 投影直通估计器:使用投影操作将连续的门控分数转换为满足预算约束的硬掩码,并使用直通估计器来传递梯度。3) 缩放因子校准:通过学习缩放因子来补偿剪枝引起的尺度变化,提高剪枝后模型的性能。4) 损失函数设计:使用交叉熵损失函数来训练门控网络,并添加正则化项来鼓励门控分数的稀疏性。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

GRASPrune在LLaMA-2-7B模型上进行了实验,结果表明,在移除50%参数的情况下,GRASPrune在WikiText-2数据集上实现了12.18的困惑度,同时在五个基准测试上保持了具有竞争力的平均零样本准确率。该方法仅使用单个NVIDIA A100 80GB GPU上的512个未标记校准序列,进行了四个epoch的训练,而无需任何完整的模型微调。

🎯 应用场景

GRASPrune具有广泛的应用前景,可用于降低大语言模型在各种场景下的部署成本,例如在资源受限的边缘设备上运行LLMs,或是在云端降低LLMs的推理延迟和内存占用。该方法能够有效减小模型体积,提高推理效率,从而推动LLMs在更多实际应用中的落地。

📄 摘要(原文)

Large language models (LLMs) are expensive to serve because model parameters, attention computation, and KV caches impose substantial memory and latency costs. We present GRASPrune, a structured pruning framework applied after pretraining that jointly prunes FFN channels and KV head groups under a single global budget. Instead of learning importance scores without constraints and applying the budget only after training, GRASPrune learns lightweight gate scores with a projected straight-through estimator that enforces a hard mask satisfying the budget at every step while keeping the backbone weights frozen. After the mask is fixed, we calibrate scaling factors on the retained units to mitigate scale mismatch caused by pruning, and fold these factors into the pruned weights to obtain a smaller dense checkpoint with no extra parameters at inference. On LLaMA-2-7B, GRASPrune removes 50% of parameters and achieves 12.18 perplexity on WikiText-2 while maintaining competitive average zero-shot accuracy on five benchmarks, using four epochs on 512 unlabeled calibration sequences on a single NVIDIA A100 80GB GPU without any full model fine-tuning.