Dataset Distillation with Probabilistic Latent Features
作者: Zhe Li, Sarah Cechnicka, Cheng Ouyang, Katharina Breininger, Peter Schüffler, Bernhard Kainz
分类: cs.CV
发布日期: 2025-05-10 (更新: 2025-05-17)
备注: 23 pages
💡 一句话要点
提出基于概率潜在特征的数据集蒸馏方法,提升跨架构泛化性能。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 数据集蒸馏 概率建模 潜在特征学习 生成模型 跨架构泛化
📋 核心要点
- 现有数据集蒸馏方法依赖像素空间到潜在空间的映射,缺乏对潜在特征空间结构的有效建模。
- 提出一种随机方法,建模潜在特征的联合分布,捕捉空间结构并生成多样化样本。
- 实验表明,该方法在多个数据集上实现了最先进的跨架构性能,验证了其通用性和有效性。
📝 摘要(中文)
随着深度学习模型复杂度和训练数据量的增长,降低存储和计算成本变得越来越重要。数据集蒸馏通过合成一组紧凑的合成数据来有效替代原始数据集,从而应对这一挑战。现有方法通常依赖于将数据从像素空间映射到生成模型的潜在空间,而我们提出了一种新的随机方法,该方法对潜在特征的联合分布进行建模。这使得我们的方法能够更好地捕捉空间结构并产生多样化的合成样本,从而有益于模型训练。具体来说,我们引入了一个由轻量级网络参数化的低秩多元正态分布。这种设计保持了较低的计算复杂度,并且与数据集蒸馏中使用的各种匹配网络兼容。蒸馏后,通过将学习到的潜在特征输入到预训练的生成器中来生成合成图像。然后,这些合成图像用于训练分类模型,并在真实测试集上评估性能。我们在包括ImageNet子集、CIFAR-10和MedMNIST组织病理学数据集在内的多个基准上验证了我们的方法。我们的方法在一系列骨干架构上实现了最先进的跨架构性能,证明了其通用性和有效性。
🔬 方法详解
问题定义:数据集蒸馏旨在用一个远小于原始数据集的合成数据集,训练出性能接近甚至超过原始数据集训练的模型。现有方法,特别是基于生成模型的方法,通常直接将像素空间的数据映射到生成模型的潜在空间,缺乏对潜在特征之间关系的建模,导致合成数据多样性不足,泛化能力受限。
核心思路:本论文的核心思路是显式地对潜在特征的联合概率分布进行建模,从而更好地捕捉潜在特征之间的空间结构,并生成更多样化的合成数据。通过学习潜在特征的分布,可以避免直接从像素空间进行映射,从而提高合成数据的质量和泛化能力。
技术框架:该方法主要包含两个阶段:1) 潜在特征分布学习阶段:使用一个轻量级网络参数化的低秩多元正态分布来建模潜在特征的联合分布。该网络以匹配网络的输出作为输入,学习分布的参数。2) 合成图像生成和模型训练阶段:从学习到的潜在特征分布中采样,将采样得到的潜在特征输入到预训练的生成器中,生成合成图像。然后,使用这些合成图像训练下游的分类模型。
关键创新:该方法最重要的技术创新点在于对潜在特征的概率建模。与以往直接映射像素空间到潜在空间的方法不同,该方法显式地学习潜在特征的联合分布,从而更好地捕捉潜在特征之间的关系,并生成更多样化的合成数据。这种概率建模方法能够提高合成数据的质量和泛化能力,从而提升下游分类模型的性能。
关键设计:关键设计包括:1) 使用低秩多元正态分布建模潜在特征,降低计算复杂度;2) 使用轻量级网络参数化该分布,方便与各种匹配网络结合;3) 使用预训练的生成器生成高质量的合成图像;4) 使用标准的分类损失函数训练下游分类模型。
🖼️ 关键图片
📊 实验亮点
该方法在ImageNet子集、CIFAR-10和MedMNIST等多个基准数据集上进行了验证,并取得了state-of-the-art的跨架构性能。实验结果表明,该方法在不同的骨干网络架构下均能有效提升分类模型的性能,证明了其通用性和有效性。具体性能提升数据未知,但原文强调了“state-of-the-art cross architecture performance”。
🎯 应用场景
该研究成果可应用于数据隐私保护、模型压缩和加速等领域。通过数据集蒸馏,可以在保护原始数据隐私的同时,生成可用于模型训练的合成数据。此外,该方法可以用于减少训练数据量,降低存储和计算成本,加速模型训练过程。该技术在医疗影像分析、自动驾驶等数据敏感且计算资源受限的领域具有广阔的应用前景。
📄 摘要(原文)
As deep learning models grow in complexity and the volume of training data increases, reducing storage and computational costs becomes increasingly important. Dataset distillation addresses this challenge by synthesizing a compact set of synthetic data that can effectively replace the original dataset in downstream classification tasks. While existing methods typically rely on mapping data from pixel space to the latent space of a generative model, we propose a novel stochastic approach that models the joint distribution of latent features. This allows our method to better capture spatial structures and produce diverse synthetic samples, which benefits model training. Specifically, we introduce a low-rank multivariate normal distribution parameterized by a lightweight network. This design maintains low computational complexity and is compatible with various matching networks used in dataset distillation. After distillation, synthetic images are generated by feeding the learned latent features into a pretrained generator. These synthetic images are then used to train classification models, and performance is evaluated on real test set. We validate our method on several benchmarks, including ImageNet subsets, CIFAR-10, and the MedMNIST histopathological dataset. Our approach achieves state-of-the-art cross architecture performance across a range of backbone architectures, demonstrating its generality and effectiveness.