TabDiff: a Mixed-type Diffusion Model for Tabular Data Generation

📄 arXiv: 2410.20626v3 📥 PDF

作者: Juntong Shi, Minkai Xu, Harper Hua, Hengrui Zhang, Stefano Ermon, Jure Leskovec

分类: cs.LG

发布日期: 2024-10-27 (更新: 2025-02-16)

期刊: ICLR 2025

🔗 代码/项目: GITHUB


💡 一句话要点

TabDiff:混合类型扩散模型用于表格数据生成,显著提升数据质量。

🎯 匹配领域: 支柱四:生成式动作 (Generative Motion)

关键词: 表格数据生成 扩散模型 混合类型数据 数据增强 隐私保护 缺失值插补 Transformer

📋 核心要点

  1. 表格数据生成面临异构数据类型、复杂相关性和列分布的挑战,现有方法难以有效建模。
  2. TabDiff提出联合扩散框架,通过特征级可学习扩散过程和混合类型采样器,提升生成质量。
  3. 实验表明,TabDiff在多个数据集上显著优于现有方法,尤其在列相关性估计方面提升显著。

📝 摘要(中文)

本文提出TabDiff,一个联合扩散框架,用于建模表格数据中所有混合类型分布。该模型的关键创新在于为数值和类别数据开发了一种联合连续时间扩散过程,并提出了特征级的可学习扩散过程,以应对不同特征分布的高度差异。TabDiff由一个处理不同输入类型的Transformer参数化,并且整个框架可以以端到端的方式高效优化。此外,还引入了一种混合类型随机采样器,以自动校正采样过程中累积的解码误差,并提出了无分类器指导,用于条件缺失列值插补。在七个数据集上的综合实验表明,TabDiff在所有八个指标上都优于现有的竞争基线,在成对列相关性估计方面,比最先进的模型提高了高达22.5%。

🔬 方法详解

问题定义:表格数据生成旨在合成高质量的表格数据,用于数据集增强和隐私保护等任务。然而,表格数据固有的异构数据类型(数值型和类别型)、复杂的列间相关性以及精细的列分布使得开发有效的生成模型极具挑战。现有方法通常难以同时处理这些问题,导致生成的数据质量不高。

核心思路:TabDiff的核心思路是使用一个统一的扩散模型来处理表格数据中混合类型的数据。通过设计一个联合的连续时间扩散过程,将数值型和类别型数据统一到一个框架中进行建模。此外,针对不同特征分布的差异,引入了特征级的可学习扩散过程,使得模型能够更好地适应不同类型的数据。

技术框架:TabDiff的整体框架包括以下几个主要模块:1) 联合扩散过程:定义了数值型和类别型数据的联合扩散和逆扩散过程。2) 特征级可学习扩散过程:为每个特征学习独立的扩散参数,以适应不同的特征分布。3) Transformer参数化:使用Transformer网络来建模数据之间的复杂关系,并处理不同类型的输入。4) 混合类型随机采样器:用于校正采样过程中累积的误差,提高生成数据的质量。5) 无分类器指导:用于条件缺失列值插补,允许模型根据已知列的信息来推断缺失列的值。

关键创新:TabDiff的关键创新在于提出了一个联合的连续时间扩散过程,能够同时处理数值型和类别型数据。与现有方法相比,TabDiff不需要对不同类型的数据进行单独建模,而是使用一个统一的模型来学习所有数据的分布。此外,特征级的可学习扩散过程使得模型能够更好地适应不同特征的分布,从而提高生成数据的质量。

关键设计:TabDiff使用Transformer网络作为扩散模型的参数化器,Transformer的输入包括数值型和类别型特征,并通过embedding层将类别型特征转换为连续向量。损失函数采用标准的扩散模型损失,即预测噪声与真实噪声之间的均方误差。混合类型随机采样器通过引入额外的噪声来校正采样过程中的误差。无分类器指导通过调整生成数据的分布,使其更接近条件分布。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

TabDiff在七个数据集上进行了全面的实验,结果表明其在所有八个指标上都优于现有的竞争基线。尤其是在成对列相关性估计方面,TabDiff比最先进的模型提高了高达22.5%。这些结果表明,TabDiff能够有效地建模表格数据中的复杂关系,并生成高质量的合成数据。

🎯 应用场景

TabDiff在数据集增强、隐私保护、缺失值插补等领域具有广泛的应用前景。它可以用于生成高质量的合成数据,从而扩充训练数据集,提高机器学习模型的性能。此外,TabDiff还可以用于生成匿名化的表格数据,保护用户的隐私。在医疗、金融等领域,TabDiff可以用于模拟真实数据,帮助研究人员进行实验和分析。

📄 摘要(原文)

Synthesizing high-quality tabular data is an important topic in many data science tasks, ranging from dataset augmentation to privacy protection. However, developing expressive generative models for tabular data is challenging due to its inherent heterogeneous data types, complex inter-correlations, and intricate column-wise distributions. In this paper, we introduce TabDiff, a joint diffusion framework that models all mixed-type distributions of tabular data in one model. Our key innovation is the development of a joint continuous-time diffusion process for numerical and categorical data, where we propose feature-wise learnable diffusion processes to counter the high disparity of different feature distributions. TabDiff is parameterized by a transformer handling different input types, and the entire framework can be efficiently optimized in an end-to-end fashion. We further introduce a mixed-type stochastic sampler to automatically correct the accumulated decoding error during sampling, and propose classifier-free guidance for conditional missing column value imputation. Comprehensive experiments on seven datasets demonstrate that TabDiff achieves superior average performance over existing competitive baselines across all eight metrics, with up to $22.5\%$ improvement over the state-of-the-art model on pair-wise column correlation estimations. Code is available at https://github.com/MinkaiXu/TabDiff.