Multi-Granular Node Pruning for Circuit Discovery
作者: Muhammad Umair Haider, Hammad Rizwan, Hassan Sajjad, A. B. Siddique
分类: cs.AI
发布日期: 2025-12-11
💡 一句话要点
提出多粒度节点剪枝方法,用于大规模语言模型中的电路发现,提升效率和精度。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 电路发现 模型剪枝 多粒度学习 大型语言模型 节点剪枝
📋 核心要点
- 现有电路发现方法计算成本高,且粒度粗糙,忽略了神经元等细粒度结构。
- 提出多粒度节点剪枝框架,通过可学习掩码和稀疏性惩罚实现高效压缩。
- 实验表明,该方法能发现更小的电路,降低内存占用,同时保持任务性能。
📝 摘要(中文)
电路发现旨在识别大型语言模型(LLM)中负责特定行为的最小子网络。现有方法主要依赖于迭代的边剪枝,这种方法计算成本高昂,并且仅限于粗粒度的单元(如注意力头或MLP块),忽略了更精细的结构(如单个神经元)。我们提出了一种用于电路发现的节点级剪枝框架,该框架解决了可扩展性和粒度限制。我们的方法在统一的优化目标中引入了跨多个粒度级别(从整个块到单个神经元)的可学习掩码。特定于粒度的稀疏性惩罚指导剪枝过程,从而可以在单个微调运行中实现全面的压缩。实验结果表明,我们的方法识别出的电路在节点数量上小于先前方法发现的电路;此外,我们证明了粗粒度方法认为重要的许多神经元实际上是不相关的,同时仍然保持了任务性能。此外,我们的方法具有显著更低的内存占用,是现有方法的5-10倍,因为它不需要将中间激活保存在内存中。
🔬 方法详解
问题定义:现有电路发现方法主要依赖于迭代的边剪枝,例如剪除注意力头或MLP块。这种方法计算量大,效率低,并且粒度较粗,无法精确识别模型中真正重要的神经元。因此,需要一种更高效、更细粒度的电路发现方法。
核心思路:论文的核心思路是在节点级别进行剪枝,并引入多粒度的剪枝策略。通过学习不同粒度级别(从整个块到单个神经元)的掩码,可以更精确地识别和去除不重要的节点,从而实现更有效的电路发现。同时,通过引入稀疏性惩罚,可以引导剪枝过程,避免过度剪枝导致性能下降。
技术框架:该方法的核心是一个统一的优化框架,其中包含多个粒度级别的剪枝掩码。这些掩码分别应用于模型中的不同层级,例如整个Transformer块、注意力头、MLP层或单个神经元。在训练过程中,通过优化一个包含任务损失和稀疏性惩罚的联合目标函数,学习这些掩码。最终,根据学习到的掩码,去除不重要的节点,得到精简后的模型。
关键创新:该方法最重要的创新点在于其多粒度的剪枝策略。与传统的粗粒度剪枝方法相比,该方法可以更精确地识别和去除不重要的节点,从而实现更高的压缩率和更好的性能。此外,该方法采用节点级别的剪枝,而不是边级别的剪枝,可以显著降低计算复杂度。
关键设计:该方法使用可学习的掩码来控制每个节点的保留或去除。掩码的值在训练过程中通过梯度下降进行优化。为了避免过度剪枝,引入了稀疏性惩罚,鼓励掩码的值趋向于0或1。具体的损失函数包括任务损失(例如交叉熵损失)和稀疏性惩罚(例如L1正则化)。此外,为了平衡不同粒度级别的剪枝力度,可以为不同粒度的掩码设置不同的稀疏性惩罚系数。
🖼️ 关键图片
📊 实验亮点
实验结果表明,该方法能够发现比现有方法更小的电路,同时保持任务性能。与现有方法相比,该方法具有显著更低的内存占用(5-10倍)。实验还证明,粗粒度方法认为重要的许多神经元实际上是不相关的,进一步验证了该方法在细粒度剪枝方面的优势。
🎯 应用场景
该研究成果可应用于大型语言模型的压缩和加速,降低模型部署成本,提升推理效率。通过识别关键电路,有助于理解模型内部机制,为模型优化和改进提供指导。此外,该方法还可应用于其他深度学习模型的压缩和加速。
📄 摘要(原文)
Circuit discovery aims to identify minimal subnetworks that are responsible for specific behaviors in large language models (LLMs). Existing approaches primarily rely on iterative edge pruning, which is computationally expensive and limited to coarse-grained units such as attention heads or MLP blocks, overlooking finer structures like individual neurons. We propose a node-level pruning framework for circuit discovery that addresses both scalability and granularity limitations. Our method introduces learnable masks across multiple levels of granularity, from entire blocks to individual neurons, within a unified optimization objective. Granularity-specific sparsity penalties guide the pruning process, allowing a comprehensive compression in a single fine-tuning run. Empirically, our approach identifies circuits that are smaller in nodes than those discovered by prior methods; moreover, we demonstrate that many neurons deemed important by coarse methods are actually irrelevant, while still maintaining task performance. Furthermore, our method has a significantly lower memory footprint, 5-10x, as it does not require keeping intermediate activations in the memory to work.