Dataset Distillation via Adversarial Prediction Matching

📄 arXiv: 2312.08912v1 📥 PDF

作者: Mingyang Chen, Bo Huang, Junda Lu, Bing Li, Yi Wang, Minhao Cheng, Wei Wang

分类: cs.CV

发布日期: 2023-12-14


💡 一句话要点

提出对抗预测匹配的数据集蒸馏方法,高效压缩数据集并保持模型性能。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 数据集蒸馏 对抗学习 模型压缩 知识迁移 单层优化

📋 核心要点

  1. 现有数据集蒸馏方法通常涉及复杂的嵌套优化或长程梯度展开,计算成本高昂,内存效率低。
  2. 论文提出对抗预测匹配方法,通过最小化模型在真实数据上的预测差异来实现数据集蒸馏,简化了优化过程。
  3. 实验表明,该方法在ImageNet-1K等数据集上表现出色,仅用少量内存即可生成高性能的蒸馏数据集。

📝 摘要(中文)

本文提出了一种新的数据集蒸馏方法,通过最小化分别在原始大数据集和小型蒸馏数据集上训练的模型在真实数据分布上的预测差异,从而将原始数据的信息压缩到蒸馏数据集中。我们提出了一个对抗框架来高效地解决这个问题。与现有的涉及嵌套优化或长程梯度展开的蒸馏方法不同,我们的方法依赖于单层优化,确保了内存效率,并在时间和内存预算之间提供了灵活的权衡,从而可以使用最少6.5GB的GPU内存来蒸馏ImageNet-1K。在最佳权衡策略下,与最先进的方法相比,我们的方法所需的内存减少了2.5倍,运行时间减少了5倍。实验结果表明,我们的方法可以生成只有原始数据集10%大小的合成数据集,但平均可以达到在完整原始数据集上训练的模型测试精度的94%,显著超过了现有技术水平。此外,大量的测试表明,我们蒸馏的数据集在跨架构泛化能力方面表现出色。

🔬 方法详解

问题定义:数据集蒸馏旨在从大型原始数据集中合成更小的、压缩的数据集,同时保留必要的信息以维持模型性能。现有方法的痛点在于计算复杂度高,内存需求大,尤其是在处理大规模数据集时,嵌套优化和长程梯度展开导致训练过程极其耗时和资源密集。

核心思路:论文的核心思路是将数据集蒸馏问题转化为一个对抗学习问题。通过最小化在原始数据集和蒸馏数据集上训练的模型在真实数据分布上的预测差异,迫使蒸馏数据集捕获原始数据集的关键信息。这种方法避免了复杂的嵌套优化,简化了训练过程。

技术框架:整体框架包含两个主要部分:一个在原始数据集上训练的模型(教师模型)和一个在蒸馏数据集上训练的模型(学生模型)。对抗训练的目标是让学生模型的预测尽可能接近教师模型的预测。具体流程如下:1) 从原始数据集中采样真实数据;2) 使用教师模型和学生模型分别对真实数据进行预测;3) 计算两个模型预测之间的差异(例如,KL散度);4) 使用对抗损失更新蒸馏数据集,使其能够生成更接近教师模型预测的结果。

关键创新:最重要的技术创新点在于将数据集蒸馏问题转化为一个单层对抗学习问题,避免了嵌套优化和长程梯度展开。这显著降低了计算复杂度和内存需求,使得在大规模数据集上进行数据集蒸馏成为可能。与现有方法的本质区别在于优化目标和优化方式。现有方法通常直接优化蒸馏数据集,使其能够最大化在蒸馏数据集上训练的模型的性能,而本文方法则侧重于最小化模型预测的差异。

关键设计:关键设计包括:1) 使用KL散度作为预测差异的度量;2) 使用Adam优化器更新蒸馏数据集;3) 在时间和内存预算之间进行权衡,通过调整蒸馏数据集的大小和训练迭代次数来优化性能。此外,论文还探索了不同的网络结构和超参数设置,以进一步提高蒸馏数据集的性能。

📊 实验亮点

实验结果表明,该方法在ImageNet-1K数据集上取得了显著的性能提升。使用仅为原始数据集10%大小的蒸馏数据集,可以达到在完整原始数据集上训练的模型测试精度的94%,显著超过了现有技术水平。此外,该方法在内存效率方面也表现出色,可以使用最少6.5GB的GPU内存来蒸馏ImageNet-1K,与最先进的方法相比,所需的内存减少了2.5倍,运行时间减少了5倍。

🎯 应用场景

该研究成果可广泛应用于资源受限的场景,例如移动设备、嵌入式系统和边缘计算。通过使用蒸馏数据集,可以在这些设备上部署高性能的机器学习模型,而无需存储和处理大型原始数据集。此外,该方法还可以用于数据隐私保护,通过发布蒸馏数据集而非原始数据,可以降低数据泄露的风险。

📄 摘要(原文)

Dataset distillation is the technique of synthesizing smaller condensed datasets from large original datasets while retaining necessary information to persist the effect. In this paper, we approach the dataset distillation problem from a novel perspective: we regard minimizing the prediction discrepancy on the real data distribution between models, which are respectively trained on the large original dataset and on the small distilled dataset, as a conduit for condensing information from the raw data into the distilled version. An adversarial framework is proposed to solve the problem efficiently. In contrast to existing distillation methods involving nested optimization or long-range gradient unrolling, our approach hinges on single-level optimization. This ensures the memory efficiency of our method and provides a flexible tradeoff between time and memory budgets, allowing us to distil ImageNet-1K using a minimum of only 6.5GB of GPU memory. Under the optimal tradeoff strategy, it requires only 2.5$\times$ less memory and 5$\times$ less runtime compared to the state-of-the-art. Empirically, our method can produce synthetic datasets just 10% the size of the original, yet achieve, on average, 94% of the test accuracy of models trained on the full original datasets including ImageNet-1K, significantly surpassing state-of-the-art. Additionally, extensive tests reveal that our distilled datasets excel in cross-architecture generalization capabilities.