Mitigating Bias in Dataset Distillation

📄 arXiv: 2406.06609v2 📥 PDF

作者: Justin Cui, Ruochen Wang, Yuanhao Xiong, Cho-Jui Hsieh

分类: cs.LG, cs.AI, cs.CV

发布日期: 2024-06-06 (更新: 2024-07-10)

备注: ICML


💡 一句话要点

提出基于核密度估计的重加权方法,缓解数据集蒸馏中的偏差放大问题

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 数据集蒸馏 偏差缓解 核密度估计 样本重加权 数据增强

📋 核心要点

  1. 数据集蒸馏在压缩大数据集时会放大原始数据集中的偏差,导致下游模型性能下降,尤其是在存在颜色和背景偏差时。
  2. 论文提出一种基于核密度估计的样本重加权方法,通过调整样本权重来减少蒸馏过程中偏差的放大效应。
  3. 实验表明,该方法在多个数据集上显著优于现有方法,例如在CMNIST数据集上,精度提升高达67.7%。

📝 摘要(中文)

数据集蒸馏是一种将大型数据集压缩成小型合成数据集的技术,便于下游训练任务。本文研究了原始数据集中的偏差对数据集蒸馏性能的影响。通过对具有颜色、损坏和背景偏差的典型数据集进行全面的经验评估,发现原始数据集中的颜色和背景偏差会通过蒸馏过程放大,导致在蒸馏数据集上训练的模型的性能显著下降,而损坏偏差则通过蒸馏过程被抑制。为了减少数据集蒸馏中的偏差放大,我们提出了一种简单而有效的基于核密度估计的样本重加权方案。在多个真实和合成数据集上的实验结果表明了该方法的有效性。特别是在偏差冲突率为5%和IPC为50的CMNIST数据集上,我们的方法达到了91.5%的测试精度,而vanilla DM只有23.8%,性能提升了67.7%,而将最先进的去偏方法应用于同一数据集仅达到53.7%的精度。我们的研究结果强调了解决数据集蒸馏中偏差的重要性,并为解决该过程中的偏差放大问题提供了一个有希望的途径。

🔬 方法详解

问题定义:数据集蒸馏旨在将大型数据集压缩成一个小的合成数据集,用于训练模型。然而,如果原始数据集中存在偏差(例如,颜色偏差、背景偏差),数据集蒸馏过程可能会放大这些偏差,导致在合成数据集上训练的模型泛化能力差,尤其是在与偏差相关的对抗性样本上表现不佳。现有方法通常忽略了数据集蒸馏过程中偏差放大的问题,或者直接应用通用的去偏方法,效果有限。

核心思路:论文的核心思路是利用核密度估计(Kernel Density Estimation, KDE)来估计数据集中样本的密度,并根据样本密度进行重加权。密度较低的样本往往代表着少数类别或者更具挑战性的样本,因此赋予更高的权重,从而平衡数据集中的偏差。这种重加权策略旨在减少蒸馏过程中对多数类别或简单样本的过度拟合,从而缓解偏差放大问题。

技术框架:该方法主要包含以下几个步骤:1) 使用原始数据集训练一个初始模型;2) 利用训练好的模型提取原始数据集中每个样本的特征向量;3) 使用核密度估计方法,基于提取的特征向量,估计每个样本的密度;4) 根据样本密度计算重加权系数,密度低的样本赋予更高的权重;5) 使用重加权后的样本进行数据集蒸馏,生成合成数据集。

关键创新:该方法的关键创新在于将核密度估计与样本重加权相结合,用于缓解数据集蒸馏中的偏差放大问题。与传统的去偏方法不同,该方法不是直接修改模型或损失函数,而是通过调整样本权重来影响蒸馏过程,从而更有效地平衡数据集中的偏差。此外,该方法简单易实现,可以与其他数据集蒸馏方法相结合。

关键设计:在核密度估计中,需要选择合适的核函数和带宽参数。论文中可能使用了高斯核函数,并采用交叉验证等方法选择最优的带宽参数。重加权系数的计算方式可能为样本密度的倒数,或者经过归一化处理后的密度倒数。在数据集蒸馏过程中,可以使用常见的蒸馏损失函数,例如交叉熵损失函数或KL散度损失函数。具体的网络结构和超参数设置需要根据具体的数据集和任务进行调整。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

该方法在CMNIST数据集上取得了显著的性能提升,在5%偏差冲突率和IPC 50的设置下,测试精度达到了91.5%,相比于vanilla DM的23.8%提升了67.7%,并且优于现有的去偏方法(53.7%)。这表明该方法能够有效地缓解数据集蒸馏中的偏差放大问题,并显著提高模型的泛化能力。实验结果充分验证了该方法的有效性和优越性。

🎯 应用场景

该研究成果可应用于各种需要数据集蒸馏的场景,尤其是在原始数据集存在偏差的情况下。例如,在自动驾驶领域,如果训练数据集中某些场景或天气条件下的数据较少,使用该方法可以生成更平衡的合成数据集,提高自动驾驶系统的鲁棒性和安全性。此外,该方法还可以应用于医疗图像分析、人脸识别等领域,提高模型在各种偏差条件下的泛化能力。

📄 摘要(原文)

Dataset Distillation has emerged as a technique for compressing large datasets into smaller synthetic counterparts, facilitating downstream training tasks. In this paper, we study the impact of bias inside the original dataset on the performance of dataset distillation. With a comprehensive empirical evaluation on canonical datasets with color, corruption and background biases, we found that color and background biases in the original dataset will be amplified through the distillation process, resulting in a notable decline in the performance of models trained on the distilled dataset, while corruption bias is suppressed through the distillation process. To reduce bias amplification in dataset distillation, we introduce a simple yet highly effective approach based on a sample reweighting scheme utilizing kernel density estimation. Empirical results on multiple real-world and synthetic datasets demonstrate the effectiveness of the proposed method. Notably, on CMNIST with 5% bias-conflict ratio and IPC 50, our method achieves 91.5% test accuracy compared to 23.8% from vanilla DM, boosting the performance by 67.7%, whereas applying state-of-the-art debiasing method on the same dataset only achieves 53.7% accuracy. Our findings highlight the importance of addressing biases in dataset distillation and provide a promising avenue to address bias amplification in the process.