Learning Generation Orders for Masked Discrete Diffusion Models via Variational Inference

📄 arXiv: 2602.23968v1 📥 PDF

作者: David Fox, Sam Bowyer, Song Liu, Laurence Aitchison, Raul Santos-Rodriguez, Mengyue Yang

分类: cs.LG

发布日期: 2026-02-27

备注: 12 pages, 1 figure


💡 一句话要点

提出基于变分推断的学习框架,优化Masked离散扩散模型的并行生成顺序

🎯 匹配领域: 支柱四:生成式动作 (Generative Motion)

关键词: Masked离散扩散模型 并行生成 变分推断 生成顺序学习 自然语言处理 生成模型 GSM8K数据集

📋 核心要点

  1. Masked离散扩散模型在并行生成方面有优势,但如何平衡并行性和生成质量是一个挑战。
  2. 论文提出基于变分推断的学习框架,优化生成顺序的后验分布,以提升并行生成效率。
  3. 实验表明,该方法在GSM8K数据集上,以更少的生成步骤实现了更高的准确率,具有竞争力。

📝 摘要(中文)

Masked离散扩散模型(MDMs)是一种有前景的生成建模新方法,它能够并行生成token,因此比自回归模型更有效。然而,在并行生成和样本质量之间实现最佳平衡仍然是一个开放的问题。目前的方法主要通过固定的、启发式的并行采样方法来解决这个问题。最近出现了一些基于学习的方法,但从变分推断的角度对其进行公式化仍未得到充分探索。在这项工作中,我们提出了一个变分推断框架,用于学习MDMs的并行生成顺序。作为我们方法的一部分,我们提出了一种生成顺序近似后验的参数化方法,该方法有助于训练期间的并行性和高效采样。使用这种方法,我们对GSM8K数据集进行了初步实验,在高度并行生成的情况下,我们的方法与启发式采样策略相比具有竞争力。例如,我们的方法在平均仅4个生成步骤的情况下实现了33.1%的准确率,而标准竞争方法在相同步骤数下实现了23.7-29.0%的准确率。我们相信对该方法的进一步实验和分析将为MDMs的并行生成问题提供有价值的见解。

🔬 方法详解

问题定义:论文旨在解决Masked离散扩散模型(MDMs)中并行生成顺序优化的问题。现有方法主要依赖固定的启发式策略,缺乏灵活性,难以在并行效率和生成质量之间取得最佳平衡。此外,基于学习的方法还不够成熟,缺乏从变分推断角度的深入研究。

核心思路:论文的核心思路是通过变分推断学习一个生成顺序的近似后验分布。通过优化这个后验分布,模型能够学习到更有效的并行生成策略,从而在保证生成质量的前提下,减少所需的生成步骤。这种方法允许模型根据数据自适应地调整生成顺序,而不是依赖于固定的规则。

技术框架:整体框架包含以下几个主要部分:1) Masked离散扩散模型(MDM):作为生成模型的基础;2) 生成顺序的近似后验分布:使用一个参数化的模型来近似生成顺序的后验分布;3) 变分推断:使用变分推断来优化近似后验分布的参数,目标是最大化数据的对数似然下界(ELBO);4) 采样:在训练完成后,使用学习到的后验分布进行采样,生成并行的token生成顺序。

关键创新:论文的关键创新在于将并行生成顺序的学习问题形式化为变分推断问题,并提出了一种新的参数化方法来表示生成顺序的近似后验分布。这种参数化方法的设计考虑了并行性和高效采样的需求,使得模型能够在训练过程中学习到更有效的并行生成策略。与现有方法相比,该方法能够自适应地调整生成顺序,从而更好地平衡并行效率和生成质量。

关键设计:论文中关于近似后验分布的具体参数化方式以及变分推断的损失函数是关键设计。具体细节未知,摘要中提到“我们提出了一种生成顺序近似后验的参数化方法,该方法有助于训练期间的并行性和高效采样”,但未给出具体公式或网络结构。损失函数的设计目标是最大化ELBO,需要考虑重构损失和KL散度等因素。

📊 实验亮点

实验结果表明,该方法在GSM8K数据集上表现出色,在平均仅4个生成步骤的情况下实现了33.1%的准确率,而标准竞争方法在相同步骤数下仅实现了23.7-29.0%的准确率。这表明该方法能够在高度并行生成的情况下,显著提高生成质量,具有很强的竞争力。

🎯 应用场景

该研究成果可应用于自然语言处理、图像生成等领域,尤其是在需要高效并行生成的场景下,例如机器翻译、文本摘要、图像修复等。通过优化生成顺序,可以显著提高生成速度,降低计算成本,并有望提升生成质量。未来,该方法可以进一步扩展到其他类型的生成模型和任务中。

📄 摘要(原文)

Masked discrete diffusion models (MDMs) are a promising new approach to generative modelling, offering the ability for parallel token generation and therefore greater efficiency than autoregressive counterparts. However, achieving an optimal balance between parallel generation and sample quality remains an open problem. Current approaches primarily address this issue through fixed, heuristic parallel sampling methods. There exist some recent learning based approaches to this problem, but its formulation from the perspective of variational inference remains underexplored. In this work, we propose a variational inference framework for learning parallel generation orders for MDMs. As part of our method, we propose a parameterisation for the approximate posterior of generation orders which facilitates parallelism and efficient sampling during training. Using this method, we conduct preliminary experiments on the GSM8K dataset, where our method performs competitively against heuristic sampling strategies in the regime of highly parallel generation. For example, our method achieves 33.1\% accuracy with an average of only only 4 generation steps, compared to 23.7-29.0\% accuracy achieved by standard competitor methods in the same number of steps. We believe further experiments and analysis of the method will yield valuable insights into the problem of parallel generation with MDMs.