Distill the Best, Ignore the Rest: Improving Dataset Distillation with Loss-Value-Based Pruning

📄 arXiv: 2411.12115v1 📥 PDF

作者: Brian B. Moser, Federico Raue, Tobias C. Nauen, Stanislav Frolov, Andreas Dengel

分类: cs.CV, cs.AI, cs.LG

发布日期: 2024-11-18

DOI: 10.1109/IJCNN64981.2025.11229108.


💡 一句话要点

提出基于损失值剪枝的数据集蒸馏方法,提升泛化性和蒸馏质量。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 数据集蒸馏 损失值剪枝 核心集选择 模型泛化 跨架构鲁棒性

📋 核心要点

  1. 现有数据集蒸馏方法忽略了数据集中可能存在的无益样本,影响了蒸馏效果和泛化能力。
  2. 论文提出“先剪枝,后蒸馏”框架,通过损失值剪枝构建代表性核心集,提高蒸馏质量。
  3. 实验表明,该方法在大量剪枝后仍能显著提升蒸馏质量,并增强了跨架构的鲁棒性。

📝 摘要(中文)

近年来,数据集蒸馏受到了广泛关注,但现有方法通常从整个数据集进行蒸馏,其中可能包含无益的样本。本文提出了一种新颖的“先剪枝,后蒸馏”框架,该框架在蒸馏之前,通过基于损失的采样系统地剪枝数据集。通过在经典蒸馏技术和生成先验之前利用剪枝,我们创建了一个具有代表性的核心集,从而增强了对未见架构的泛化能力——这是当前蒸馏方法的一个重大挑战。更具体地说,我们提出的框架显著提高了蒸馏质量,即使在大量数据集剪枝的情况下(即在蒸馏之前移除原始数据集的 80%),也能实现高达 5.2 个百分点的准确率提升。总的来说,我们的实验结果突出了我们易于采样的优先级和跨架构鲁棒性的优势,为更有效和高质量的数据集蒸馏铺平了道路。

🔬 方法详解

问题定义:数据集蒸馏旨在用一个小的合成数据集来代表原始大数据集,从而加速模型训练和降低存储成本。然而,现有方法通常直接从整个数据集进行蒸馏,忽略了数据集中可能存在的噪声样本或对模型训练无益的样本。这些样本会降低蒸馏效率,并影响蒸馏后模型的泛化能力。因此,如何从原始数据集中选择最具代表性的样本进行蒸馏是一个关键问题。

核心思路:论文的核心思路是“先剪枝,后蒸馏”。即首先通过一种基于损失值的采样方法,从原始数据集中识别并移除对模型训练贡献较小的样本,从而构建一个更具代表性的核心集。然后,再利用经典的数据集蒸馏技术,从这个核心集中生成合成数据集。这种方法可以有效地去除噪声样本,提高蒸馏效率,并增强蒸馏后模型的泛化能力。

技术框架:该框架主要包含两个阶段:剪枝阶段和蒸馏阶段。在剪枝阶段,首先使用一个预训练的模型在原始数据集上进行训练,并计算每个样本的损失值。然后,根据损失值的大小,对样本进行排序,并移除损失值较低的样本,从而得到一个剪枝后的核心集。在蒸馏阶段,使用经典的数据集蒸馏技术,例如匹配训练轨迹或梯度匹配,从核心集中生成合成数据集。

关键创新:该方法最重要的创新点在于提出了“先剪枝,后蒸馏”的框架。与现有方法直接从整个数据集进行蒸馏不同,该方法首先通过损失值剪枝,去除噪声样本,从而提高蒸馏效率和泛化能力。此外,该方法还具有良好的跨架构鲁棒性,即使用一个架构进行剪枝后得到的核心集,可以用于训练其他架构的模型。

关键设计:在剪枝阶段,损失值的计算方式是一个关键设计。论文中使用预训练模型的损失值作为样本重要性的度量。具体来说,可以使用交叉熵损失或均方误差损失等。此外,剪枝比例也是一个重要的参数。论文通过实验分析了不同剪枝比例对蒸馏效果的影响,并给出了合理的建议值。在蒸馏阶段,可以使用各种经典的数据集蒸馏技术,例如匹配训练轨迹或梯度匹配。具体选择哪种技术取决于具体的应用场景和需求。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在多个数据集上都取得了显著的性能提升。例如,在CIFAR-10数据集上,即使在移除原始数据集的80%后,该方法仍然可以实现高达5.2个百分点的准确率提升。此外,该方法还表现出良好的跨架构鲁棒性,即使用一个架构进行剪枝后得到的核心集,可以用于训练其他架构的模型,并且仍然可以取得良好的性能。

🎯 应用场景

该研究成果可应用于各种需要数据集蒸馏的场景,例如模型压缩、联邦学习、数据隐私保护等。通过先剪枝再蒸馏,可以有效地降低数据集的规模,提高训练效率,并增强模型的泛化能力。此外,该方法还可以用于构建高质量的合成数据集,用于数据增强或模型预训练。

📄 摘要(原文)

Dataset distillation has gained significant interest in recent years, yet existing approaches typically distill from the entire dataset, potentially including non-beneficial samples. We introduce a novel "Prune First, Distill After" framework that systematically prunes datasets via loss-based sampling prior to distillation. By leveraging pruning before classical distillation techniques and generative priors, we create a representative core-set that leads to enhanced generalization for unseen architectures - a significant challenge of current distillation methods. More specifically, our proposed framework significantly boosts distilled quality, achieving up to a 5.2 percentage points accuracy increase even with substantial dataset pruning, i.e., removing 80% of the original dataset prior to distillation. Overall, our experimental results highlight the advantages of our easy-sample prioritization and cross-architecture robustness, paving the way for more effective and high-quality dataset distillation.