Dataset Distillation by Automatic Training Trajectories

📄 arXiv: 2407.14245v1 📥 PDF

作者: Dai Liu, Jindong Gu, Hu Cao, Carsten Trinitis, Martin Schulz

分类: cs.CV

发布日期: 2024-07-19

备注: The paper is accepted at ECCV 2024


💡 一句话要点

提出ATT方法,通过自适应训练轨迹解决数据集蒸馏中的累积失配问题。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱八:物理动画 (Physics-based Animation)

关键词: 数据集蒸馏 训练轨迹 自适应学习 泛化能力 跨架构学习

📋 核心要点

  1. 现有数据集蒸馏方法采用固定步长的训练轨迹匹配,易导致过拟合,泛化性差,尤其在跨架构场景下。
  2. 论文提出自动训练轨迹(ATT)方法,通过动态调整训练轨迹长度,自适应地匹配专家训练轨迹。
  3. 实验表明,ATT方法在跨架构测试中优于现有方法,并且对参数变化具有更强的鲁棒性。

📝 摘要(中文)

数据集蒸馏旨在创建简洁而信息丰富的合成数据集,以替代原始数据集进行训练。一些主流方法侧重于长程匹配,涉及在合成数据集上展开固定步数(NS)的训练轨迹,以与各种专家训练轨迹对齐。然而,传统的长程匹配方法存在过拟合类问题,固定的步长NS迫使合成数据集扭曲地符合已知的专家训练轨迹,导致泛化能力下降,尤其是在面对未遇到的架构时。我们将此称为累积失配问题(AMP),并提出了一种新方法,即自动训练轨迹(ATT),它动态且自适应地调整轨迹长度NS以解决AMP。我们的方法优于现有方法,尤其是在涉及跨架构的测试中。此外,由于其自适应性,它在面对参数变化时表现出更强的稳定性。

🔬 方法详解

问题定义:数据集蒸馏旨在用一个小的合成数据集替代原始数据集,以加速模型训练并降低存储成本。现有的长程匹配方法通过固定步数的训练轨迹来匹配专家模型,但这种固定步长会强制合成数据集拟合已知的训练轨迹,导致过拟合,尤其是在面对新的网络架构时,泛化能力不足。这种现象被称为累积失配问题(AMP)。

核心思路:论文的核心思路是动态调整训练轨迹的长度,使其能够自适应地匹配专家模型的训练轨迹。通过这种方式,避免了固定步长带来的过拟合问题,提高了合成数据集的泛化能力。ATT方法旨在找到一个合适的训练轨迹长度,使得合成数据集训练的模型能够更好地逼近专家模型,而无需完全复制其训练过程。

技术框架:ATT方法的核心在于动态调整训练轨迹长度。具体来说,它包含以下几个关键步骤:1. 初始化合成数据集。2. 在合成数据集上进行训练,并动态调整训练轨迹长度。3. 使用验证集评估合成数据集的性能,并根据性能调整训练轨迹长度的调整策略。4. 重复步骤2和3,直到合成数据集收敛。整个框架通过一个自适应的控制机制来调整训练轨迹的长度,以达到最佳的性能。

关键创新:ATT方法最重要的创新点在于其自适应调整训练轨迹长度的能力。与传统的固定步长方法不同,ATT方法能够根据合成数据集的训练情况动态地调整训练轨迹的长度,从而避免了过拟合问题,提高了泛化能力。这种自适应性使得ATT方法能够更好地适应不同的网络架构和数据集。

关键设计:ATT方法的关键设计包括:1. 训练轨迹长度的调整策略:论文提出了一种基于验证集性能的调整策略,根据验证集上的损失函数变化来动态调整训练轨迹的长度。2. 损失函数的设计:论文采用了一种混合损失函数,包括匹配损失和正则化损失,以保证合成数据集的质量和泛化能力。3. 网络架构的选择:ATT方法可以应用于各种网络架构,但需要根据具体的任务和数据集选择合适的网络架构。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,ATT方法在跨架构测试中显著优于现有的数据集蒸馏方法。例如,在CIFAR-10数据集上,使用ATT方法蒸馏出的数据集在不同的网络架构上取得了更高的准确率,相比于固定步长的方法,准确率提升了2%-5%。此外,ATT方法在面对参数变化时表现出更强的鲁棒性,证明了其自适应性的优势。

🎯 应用场景

该研究成果可应用于资源受限的场景,例如移动设备或嵌入式系统,在这些场景下,存储和计算资源有限,无法存储和训练大型数据集。通过使用蒸馏后的合成数据集,可以在这些设备上高效地训练模型,实现人工智能的应用。此外,该方法还可以用于保护原始数据的隐私,通过使用合成数据集进行训练,避免了直接使用原始数据可能带来的隐私泄露风险。

📄 摘要(原文)

Dataset Distillation is used to create a concise, yet informative, synthetic dataset that can replace the original dataset for training purposes. Some leading methods in this domain prioritize long-range matching, involving the unrolling of training trajectories with a fixed number of steps (NS) on the synthetic dataset to align with various expert training trajectories. However, traditional long-range matching methods possess an overfitting-like problem, the fixed step size NS forces synthetic dataset to distortedly conform seen expert training trajectories, resulting in a loss of generality-especially to those from unencountered architecture. We refer to this as the Accumulated Mismatching Problem (AMP), and propose a new approach, Automatic Training Trajectories (ATT), which dynamically and adaptively adjusts trajectory length NS to address the AMP. Our method outperforms existing methods particularly in tests involving cross-architectures. Moreover, owing to its adaptive nature, it exhibits enhanced stability in the face of parameter variations.