TAPIOCA: Why Task- Aware Pruning Improves OOD model Capability
作者: Krish Sharma, Omar Naim, Soumadeep Saha, Nicholas Asher
分类: cs.LG, cs.AI
发布日期: 2026-05-14
💡 一句话要点
任务感知剪枝提升模型泛化能力,改善OOD数据表现
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 任务感知剪枝 异分布泛化 模型鲁棒性 表征几何 分布偏移
📋 核心要点
- 现有任务感知剪枝方法在特定任务上表现出性能提升,但其内在原因尚不明确。
- 论文提出任务感知剪枝通过校正OOD数据引起的几何扭曲,使其与ID数据的任务适应几何结构对齐,从而提升泛化能力。
- 实验结果表明,任务感知剪枝在ID数据上无明显优势,但在OOD数据上能持续提高准确性,并在不同模型规模上保持一致。
📝 摘要(中文)
本文研究了任务感知层剪枝(task-aware layer pruning)提升模型性能的原因。研究表明,在可控的多项式回归任务和大型语言模型中,这种剪枝方法在同分布(ID)数据上没有明显优势,但能持续提高异分布(OOD)数据的准确性。实证分析表明,OOD输入会引起层级的范数和成对距离分布偏离ID数据的分布。据此,论文提出了任务感知剪枝的几何解释:每个任务都诱导出一个任务适应的几何结构,其特征在于ID输入上的表征分布。OOD输入会扭曲这种几何结构。任务感知剪枝识别并移除那些产生或放大这种扭曲的层,从而使OOD表征的范数和成对距离向ID数据靠拢,最终改善模型在OOD数据上的性能。论文通过受控的分布偏移和残差缩放干预提供了因果证据,并在不同模型规模上验证了一致性。
🔬 方法详解
问题定义:现有研究表明,任务感知剪枝可以提升模型在特定任务上的性能,但缺乏对其有效性的深入解释。尤其是在模型面对与训练数据分布不同的OOD数据时,任务感知剪枝为何能够提升泛化能力仍然是一个待解决的问题。现有方法未能充分解释剪枝如何影响模型对OOD数据的表征和预测。
核心思路:论文的核心思路是,每个任务都会在模型的表征空间中诱导出一个特定的几何结构,而OOD数据会扭曲这种几何结构。任务感知剪枝通过移除那些加剧这种扭曲的层,使得OOD数据的表征更接近ID数据的表征,从而提升模型在OOD数据上的性能。这种方法的核心在于识别并移除对OOD数据产生负面影响的层。
技术框架:论文采用了一种实证研究的方法,结合理论分析来验证其核心思路。首先,在可控的多项式回归任务和大型语言模型上进行实验,观察任务感知剪枝在ID和OOD数据上的表现。然后,通过分析ID和OOD数据的层级范数和成对距离分布,揭示OOD数据引起的几何扭曲。最后,通过受控的分布偏移和残差缩放干预,验证任务感知剪枝的因果效应。
关键创新:论文最重要的技术创新点在于提出了任务感知剪枝的几何解释。与以往关注剪枝对模型容量或计算效率的影响不同,论文从表征几何的角度解释了剪枝如何影响模型对OOD数据的泛化能力。这种几何解释为理解和改进剪枝算法提供了新的视角。
关键设计:论文的关键设计包括:1) 使用多项式回归任务和大型语言模型作为实验对象,以验证结论的普适性;2) 通过分析层级范数和成对距离分布来量化ID和OOD数据的表征差异;3) 设计受控的分布偏移和残差缩放干预,以验证任务感知剪枝的因果效应。具体的参数设置和网络结构根据不同的实验对象进行调整,以保证实验的有效性和可靠性。
🖼️ 关键图片
📊 实验亮点
实验结果表明,任务感知剪枝在ID数据上没有明显优势,但在OOD数据上能持续提高准确性。例如,在大型语言模型上,任务感知剪枝可以显著提升模型在分布外数据集上的困惑度(perplexity),并且这种提升在不同模型规模上保持一致。通过受控的分布偏移和残差缩放干预,论文提供了因果证据,进一步验证了任务感知剪枝的有效性。
🎯 应用场景
该研究成果可应用于提升机器学习模型在实际部署中的鲁棒性和泛化能力,尤其是在面对数据分布偏移的场景下。例如,在自动驾驶、医疗诊断等领域,模型需要处理各种未知的、分布外的数据,任务感知剪枝可以帮助模型更好地适应这些数据,提高安全性和可靠性。此外,该研究也为模型剪枝算法的设计提供了新的思路。
📄 摘要(原文)
Recent work has promoted task-aware layer pruning as a way to improve model performance on particular tasks, as shown by TALE. In this paper, we investigate when such improvements occur and why. We show first that, across controlled polynomial regression tasks and large language models, such pruning yields no benefit on in-distribution (ID) data but consistently improves out-of-distribution (OOD) accuracy. We further show empirically that OOD inputs induce layerwise norm and pairwise-distance profiles that deviate from the corresponding ID profiles. This leads to a geometric explanation of task-aware pruning: each task induces a task-adapted geometry, characterized empirically by the representation profiles observed on ID inputs. OOD inputs can introduce a distorted version of the task-adapted geometry. Task-aware pruning identifies layers that create or amplify this distortion; by removing them, it shifts OOD representational norms and pairwise distances toward those observed on the adapted distribution. This realigns OOD inputs with the model's task-adapted geometry and improves performance. We provide causal evidence through controlled distribution shifts and residual-scaling interventions, and demonstrate consistent behavior across model scales.