MetaDD: Boosting Dataset Distillation with Neural Network Architecture-Invariant Generalization

📄 arXiv: 2410.05103v1 📥 PDF

作者: Yunlong Zhao, Xiaoheng Deng, Xiu Su, Hongyan Xu, Xiuxing Li, Yijing Liu, Shan You

分类: cs.CV

发布日期: 2024-10-07


💡 一句话要点

MetaDD:通过神经网络架构不变泛化提升数据集蒸馏性能

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 数据集蒸馏 元学习 跨架构泛化 神经网络 特征对齐

📋 核心要点

  1. 数据集蒸馏方法受限于特定神经网络架构,跨架构泛化能力差,导致训练其他架构时性能下降。
  2. MetaDD将蒸馏数据分解为架构不变的元特征和架构特定的异构特征,并设计架构不变损失函数。
  3. MetaDD作为轻量级组件,可集成到现有数据集蒸馏方法中,并在跨架构泛化性能上取得显著提升。

📝 摘要(中文)

数据集蒸馏(DD)旨在从大规模数据集中创建精简的蒸馏数据集,以促进高效训练。DD的一个重要挑战是蒸馏数据集与所使用的神经网络(NN)架构之间的依赖关系。使用特定架构蒸馏的数据集训练不同的NN架构通常会导致其他架构的训练性能下降。本文介绍了MetaDD,旨在增强DD在各种NN架构中的泛化能力。具体来说,MetaDD将蒸馏数据划分为元特征(即在不同NN架构中保持一致的数据的共同特征)和异构特征(即每个NN架构的数据的独特特征)。然后,MetaDD采用架构不变的损失函数进行多架构特征对齐,从而增加蒸馏数据中的元特征并减少异构特征。作为一个低内存消耗的组件,MetaDD可以无缝集成到任何DD方法中。实验结果表明,MetaDD显著提高了各种DD方法的性能。在Sre2L(50 IPC)蒸馏的Tiny-Imagenet上,MetaDD实现了高达30.1%的跨架构NN精度,超过了第二好的方法(GLaD)1.7%。

🔬 方法详解

问题定义:数据集蒸馏旨在用小规模数据集替代大规模数据集进行模型训练,以降低计算成本。然而,现有的数据集蒸馏方法通常针对特定神经网络架构进行优化,导致蒸馏出的数据集在其他架构上表现不佳,泛化能力不足。这种架构依赖性限制了数据集蒸馏的实际应用价值。

核心思路:MetaDD的核心思路是将蒸馏数据集中的特征分解为两部分:元特征(Meta Features)和异构特征(Heterogeneous Features)。元特征代表了不同神经网络架构都能学习到的通用特征,而异构特征则代表了特定架构才能学习到的特征。通过增强元特征并抑制异构特征,可以提高蒸馏数据集的跨架构泛化能力。

技术框架:MetaDD可以作为一个独立的模块集成到现有的数据集蒸馏流程中。其主要流程包括:1)使用现有的数据集蒸馏方法生成初始的蒸馏数据集;2)将蒸馏数据集输入到MetaDD模块中;3)MetaDD模块将数据分解为元特征和异构特征,并使用架构不变的损失函数进行优化,以增强元特征并抑制异构特征;4)输出优化后的蒸馏数据集。

关键创新:MetaDD的关键创新在于提出了元特征和异构特征的概念,并设计了相应的架构不变损失函数。该损失函数旨在对齐不同架构下的特征表示,从而使得蒸馏数据集能够更好地泛化到不同的神经网络架构上。这种特征解耦和对齐的思想是MetaDD能够提升跨架构泛化能力的关键。

关键设计:MetaDD的关键设计包括:1)元特征和异构特征的分解方法,具体实现方式未知;2)架构不变损失函数的设计,该损失函数需要能够衡量不同架构下特征表示的相似度,并引导模型学习到架构不变的特征;3)MetaDD模块的集成方式,需要保证其能够无缝集成到现有的数据集蒸馏流程中,并且不会引入过多的计算开销。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

MetaDD在Distilled Tiny-Imagenet数据集上取得了显著的性能提升。在使用Sre2L(50 IPC)进行数据集蒸馏时,MetaDD实现了高达30.1%的跨架构神经网络精度,超过了第二好的方法GLaD 1.7%。实验结果表明,MetaDD能够有效地提高数据集蒸馏的跨架构泛化能力,使其在不同的神经网络架构上都能取得良好的性能。

🎯 应用场景

MetaDD可应用于资源受限场景下的模型训练,例如移动设备或嵌入式系统。通过蒸馏数据集,可以在这些设备上高效地训练模型,而无需访问原始大规模数据集。此外,MetaDD的跨架构泛化能力使得蒸馏数据集可以被用于训练各种不同的模型架构,提高了其通用性和灵活性。未来,MetaDD可以进一步扩展到其他领域,例如联邦学习和持续学习。

📄 摘要(原文)

Dataset distillation (DD) entails creating a refined, compact distilled dataset from a large-scale dataset to facilitate efficient training. A significant challenge in DD is the dependency between the distilled dataset and the neural network (NN) architecture used. Training a different NN architecture with a distilled dataset distilled using a specific architecture often results in diminished trainning performance for other architectures. This paper introduces MetaDD, designed to enhance the generalizability of DD across various NN architectures. Specifically, MetaDD partitions distilled data into meta features (i.e., the data's common characteristics that remain consistent across different NN architectures) and heterogeneous features (i.e., the data's unique feature to each NN architecture). Then, MetaDD employs an architecture-invariant loss function for multi-architecture feature alignment, which increases meta features and reduces heterogeneous features in distilled data. As a low-memory consumption component, MetaDD can be seamlessly integrated into any DD methodology. Experimental results demonstrate that MetaDD significantly improves performance across various DD methods. On the Distilled Tiny-Imagenet with Sre2L (50 IPC), MetaDD achieves cross-architecture NN accuracy of up to 30.1\%, surpassing the second-best method (GLaD) by 1.7\%.