Improving Discrete Diffusion Unmasking Policies Beyond Explicit Reference Policies

📄 arXiv: 2510.05725v1 📥 PDF

作者: Chunsan Hong, Seonho An, Min-Soo Kim, Jong Chul Ye

分类: cs.LG, cs.AI, cs.CL

发布日期: 2025-10-07

备注: Preprint


💡 一句话要点

提出基于KL正则化MDP的离散扩散模型Unmasking策略学习方法,显著提升性能。

🎯 匹配领域: 支柱四:生成式动作 (Generative Motion)

关键词: 离散扩散模型 掩码语言模型 马尔可夫决策过程 强化学习 策略优化 KL正则化 Unmasking策略 文本生成

📋 核心要点

  1. 现有Masked Diffusion Models的性能高度依赖于手工设计的Unmasking策略,缺乏自适应性。
  2. 将Unmasking过程建模为KL正则化的马尔可夫决策过程,通过学习优化策略,提升模型性能。
  3. 实验表明,该方法在多个基准测试中显著优于传统Max-Confidence策略,尤其在SUDOKU数据集上提升显著。

📝 摘要(中文)

掩码扩散模型(MDMs)作为一种新型语言建模框架,通过迭代地去噪被掩盖的序列来生成句子,逐步填充[MASK]标记。MDMs支持任意顺序采样,但性能对下一个要unmask的位置选择高度敏感。先前的工作通常依赖于基于规则的策略(例如,最大置信度、最大边际),这些策略提供了临时性的改进。本文提出用学习到的调度器来替代这些启发式方法。具体来说,将去噪过程建模为一个具有显式参考策略的KL正则化马尔可夫决策过程(MDP),并优化一个正则化目标,该目标在标准假设下允许策略改进和收敛保证。证明了在该框架下优化的策略比启发式策略更接近数据分布。在四个基准测试中,学习到的策略始终优于最大置信度策略:例如,在unmask顺序至关重要的SUDOKU数据集上,性能比随机策略高20.1%,比最大置信度策略高11.2%。

🔬 方法详解

问题定义:论文旨在解决离散扩散模型中,如何选择最佳的unmasking顺序的问题。现有方法,如最大置信度或最大边际,依赖于人工设计的启发式规则,缺乏自适应性,无法充分利用数据中的信息,导致模型性能受限。

核心思路:论文的核心思路是将unmasking过程建模为一个马尔可夫决策过程(MDP),并利用强化学习的方法来学习一个最优的unmasking策略。通过引入KL正则化项,鼓励学习到的策略接近一个参考策略,从而保证训练的稳定性和收敛性。

技术框架:整体框架包含以下几个关键部分:1) 状态空间:表示当前被mask的序列状态。2) 动作空间:表示选择哪个位置进行unmask。3) 奖励函数:用于评估unmasking动作的好坏,通常基于模型预测的置信度或与真实值的差距。4) 参考策略:一个预定义的unmasking策略,例如随机策略或最大置信度策略。5) KL正则化项:用于约束学习到的策略与参考策略之间的差异。通过优化一个包含奖励和KL正则化项的目标函数,学习一个最优的unmasking策略。

关键创新:最重要的创新在于将unmasking策略的学习问题形式化为一个KL正则化的MDP。与传统的启发式方法相比,该方法能够自适应地学习最优的unmasking策略,从而更好地利用数据中的信息。此外,KL正则化项的使用保证了训练的稳定性和收敛性。

关键设计:论文中,状态空间定义为当前被mask的序列,动作空间定义为选择哪个位置进行unmask。奖励函数可以基于模型预测的置信度或与真实值的差距来设计。参考策略可以选择随机策略或最大置信度策略。KL正则化项的系数需要根据具体任务进行调整。目标函数通常使用策略梯度方法进行优化。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在四个基准测试中 consistently 优于 max-confidence 策略。例如,在 SUDOKU 数据集上,该方法比随机策略高 20.1%,比 max-confidence 策略高 11.2%。这些结果表明,学习到的 unmasking 策略能够更好地利用数据中的信息,从而提高模型性能。

🎯 应用场景

该研究成果可应用于各种自然语言处理任务,例如文本生成、机器翻译、代码生成等。通过学习最优的unmasking策略,可以提高生成文本的质量和流畅度,提升模型的性能。此外,该方法还可以应用于其他离散序列生成任务,例如图像生成、音频生成等。

📄 摘要(原文)

Masked diffusion models (MDMs) have recently emerged as a novel framework for language modeling. MDMs generate sentences by iteratively denoising masked sequences, filling in [MASK] tokens step by step. Although MDMs support any-order sampling, performance is highly sensitive to the choice of which position to unmask next. Prior work typically relies on rule-based schedules (e.g., max-confidence, max-margin), which provide ad hoc improvements. In contrast, we replace these heuristics with a learned scheduler. Specifically, we cast denoising as a KL-regularized Markov decision process (MDP) with an explicit reference policy and optimize a regularized objective that admits policy improvement and convergence guarantees under standard assumptions. We prove that the optimized policy under this framework generates samples that more closely match the data distribution than heuristic schedules. Empirically, across four benchmarks, our learned policy consistently outperforms max-confidence: for example, on SUDOKU, where unmasking order is critical, it yields a 20.1% gain over random and a 11.2% gain over max-confidence.