Saliency-driven Dynamic Token Pruning for Large Language Models
作者: Yao Tao, Yehui Tang, Yun Wang, Mingjian Zhu, Hailin Hu, Yunhe Wang
分类: cs.CL, cs.AI
发布日期: 2025-04-06 (更新: 2025-04-09)
💡 一句话要点
提出基于显著性的动态Token剪枝方法SDTP,加速LLM长序列推理。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大语言模型 token剪枝 显著性驱动 动态剪枝 长序列推理 模型加速 模型压缩
📋 核心要点
- LLM在长序列推理中面临注意力机制带来的计算复杂度挑战,现有方法难以兼顾效率与性能。
- SDTP通过显著性驱动的预测模块,动态评估并剪枝冗余token,降低计算量,提升推理速度。
- 实验表明,SDTP在多种模型和数据集上有效,可在保持性能的同时显著减少FLOPs并加速推理。
📝 摘要(中文)
大型语言模型(LLM)在长序列推理场景中面临注意力机制的二次计算复杂度挑战。受神经网络模型中特征归因可解释性理论的启发,本文观察到并非所有token的贡献相同。基于此,提出了一种新的token剪枝框架,即基于显著性的动态Token剪枝(SDTP),以根据输入上下文逐步动态地剪枝冗余token。具体而言,设计了一个轻量级的显著性驱动预测模块,以估计每个token的重要性得分,并将其添加到LLM的不同层以分层剪枝冗余token。此外,提出了一种基于排序的优化策略,以最小化显著性得分和预测重要性得分的排序差异。大量实验表明,该框架可推广到各种模型和数据集。通过分层剪枝65%的输入token,该方法显著减少了33%~47%的FLOPs,并在推理期间实现了高达1.75倍的加速,同时保持了相当的性能。进一步证明,SDTP可以与KV缓存压缩方法结合使用,以实现进一步的压缩。
🔬 方法详解
问题定义:大型语言模型在处理长序列时,由于注意力机制的二次方复杂度,计算成本非常高昂。现有的token剪枝方法可能无法准确识别和去除冗余token,导致性能下降或加速效果不佳。因此,如何在保证模型性能的前提下,高效地剪枝冗余token,降低计算复杂度,是本文要解决的核心问题。
核心思路:本文的核心思路是基于token的显著性进行动态剪枝。作者观察到并非所有token对最终的预测贡献相同,因此可以通过评估每个token的重要性,并去除贡献较小的token来降低计算量。通过设计一个轻量级的显著性预测模块,可以动态地根据输入上下文评估token的重要性,从而实现更有效的剪枝。
技术框架:SDTP框架主要包含以下几个模块:1) 嵌入层:将输入文本转换为token嵌入表示。2) LLM层:使用LLM提取token的上下文信息。3) 显著性预测模块:基于LLM的隐藏状态,预测每个token的重要性得分。4) 剪枝模块:根据重要性得分,动态地剪枝冗余token。5) 输出层:基于剪枝后的token表示,进行最终的预测。该框架在LLM的不同层级添加显著性预测模块,实现分层剪枝。
关键创新:SDTP的关键创新在于:1) 提出了显著性驱动的动态token剪枝方法,能够根据输入上下文动态地评估token的重要性,并进行剪枝。2) 设计了一个轻量级的显著性预测模块,可以高效地预测token的重要性得分。3) 提出了基于排序的优化策略,最小化显著性得分和预测重要性得分的排序差异,提高剪枝的准确性。
关键设计:显著性预测模块是一个轻量级的前馈神经网络,输入是LLM的隐藏状态,输出是每个token的重要性得分。损失函数采用基于排序的损失函数,例如RankNet loss或Margin Ranking Loss,以最小化显著性得分和预测重要性得分的排序差异。剪枝比例是一个超参数,可以根据不同的模型和数据集进行调整。作者通过实验发现,剪枝65%的token可以在保持性能的同时显著降低计算量。
🖼️ 关键图片
📊 实验亮点
实验结果表明,SDTP在各种模型和数据集上都表现出良好的性能。通过分层剪枝65%的输入token,该方法显著减少了33%~47%的FLOPs,并在推理期间实现了高达1.75倍的加速,同时保持了与原始模型相当的性能。SDTP还可以与KV缓存压缩方法结合使用,以实现进一步的压缩。
🎯 应用场景
该研究成果可广泛应用于需要处理长序列的LLM应用场景,例如机器翻译、文本摘要、问答系统等。通过降低计算复杂度,可以加速推理过程,降低部署成本,并使得LLM能够在资源受限的设备上运行。此外,该方法还可以与其他模型压缩技术结合使用,进一步提高模型的效率。
📄 摘要(原文)
Despite the recent success of large language models (LLMs), LLMs are particularly challenging in long-sequence inference scenarios due to the quadratic computational complexity of the attention mechanism. Inspired by the interpretability theory of feature attribution in neural network models, we observe that not all tokens have the same contribution. Based on this observation, we propose a novel token pruning framework, namely Saliency-driven Dynamic Token Pruning (SDTP), to gradually and dynamically prune redundant tokens based on the input context. Specifically, a lightweight saliency-driven prediction module is designed to estimate the importance score of each token with its hidden state, which is added to different layers of the LLM to hierarchically prune redundant tokens. Furthermore, a ranking-based optimization strategy is proposed to minimize the ranking divergence of the saliency score and the predicted importance score. Extensive experiments have shown that our framework is generalizable to various models and datasets. By hierarchically pruning 65\% of the input tokens, our method greatly reduces 33\% $\sim$ 47\% FLOPs and achieves speedup up to 1.75$\times$ during inference, while maintaining comparable performance. We further demonstrate that SDTP can be combined with KV cache compression method for further compression.