Understanding and Accelerating the Training of Masked Diffusion Language Models
作者: Chunsan Hong, Sanghyun Lee, Chieh-Hsin Lai, Satoshi Hayakawa, Yuhta Takida, Yuki Mitsufuji, Seungryong Kim, Jong Chul Ye
分类: cs.LG, cs.AI, cs.CL
发布日期: 2026-05-13
备注: Preprint
💡 一句话要点
提出钟形时间采样策略,加速Masked Diffusion语言模型的训练。
🎯 匹配领域: 支柱四:生成式动作 (Generative Motion)
关键词: Masked Diffusion模型 语言建模 训练加速 钟形时间采样 局部性偏差
📋 核心要点
- Masked Diffusion模型训练缓慢,限制了其在大规模语言建模中的应用。
- 提出钟形时间采样策略,通过调整训练过程中噪声水平的采样分布,缓解局部性偏差的影响。
- 实验表明,该方法在多个benchmark上显著加速了MDM的训练,同时保持或提升了模型性能。
📝 摘要(中文)
Masked diffusion模型(MDMs)作为语言建模中自回归模型(ARMs)的一种有前景的替代方案而出现。然而,众所周知,MDMs的学习速度明显慢于ARMs,这在将MDMs扩展到更大的模型时可能会成为问题。因此,我们提出以下问题:如何在保持其最终性能的同时加速标准MDM训练?为此,我们首先详细分析了MDM训练速度慢的原因。我们发现主要因素是语言的局部性偏差:token的预测信息集中在附近的位置。我们进一步研究了这种偏差如何减缓学习,并提出了一个简单而有效的补救措施:钟形时间采样作为一种训练策略。值得注意的是,使用我们的训练方法训练的MDMs在One Billion Word Benchmark (LM1B)上达到相同的验证负对数似然(NLL)的速度比标准训练快约4倍。我们还展示了在各种基准测试中生成困惑度、零样本困惑度和下游任务性能方面的更快改进。
🔬 方法详解
问题定义:论文旨在解决Masked Diffusion Model (MDM) 训练速度慢的问题。现有的MDM训练方法收敛速度远低于自回归模型(ARM),这阻碍了MDM在更大规模数据集和模型上的应用。MDM训练慢的主要原因是语言的局部性偏差,即预测一个token所需的信息主要集中在附近的token上。
核心思路:论文的核心思路是调整MDM训练过程中噪声水平的采样分布,使其更关注对模型学习更有价值的时间步。具体来说,论文提出了一种钟形时间采样策略,该策略在训练初期和末期采样更多的噪声水平,而在中间阶段采样较少的噪声水平。这种策略旨在缓解局部性偏差的影响,并加速模型的学习。
技术框架:论文提出的方法主要集中在MDM的训练阶段,不涉及对MDM模型结构的修改。训练流程如下: 1. 数据准备:使用标准的语言建模数据集,如One Billion Word Benchmark (LM1B)。 2. 模型初始化:初始化一个标准的MDM模型。 3. 钟形时间采样:在每个训练步骤中,根据钟形分布采样一个时间步t。 4. 噪声添加:根据采样的时间步t,向输入文本添加噪声。 5. 模型预测:MDM模型预测原始文本。 6. 损失计算:计算模型预测与原始文本之间的损失。 7. 模型更新:使用梯度下降法更新模型参数。
关键创新:论文的关键创新在于提出了钟形时间采样策略。与传统的均匀时间采样相比,钟形时间采样能够更有效地利用训练数据,并加速模型的学习。这种方法简单有效,不需要对MDM模型结构进行任何修改。
关键设计:钟形时间采样的具体实现方式是使用一个高斯分布来定义时间步的采样概率。高斯分布的均值位于时间步的中间位置,标准差是一个可调节的超参数。通过调整标准差,可以控制钟形曲线的形状,从而影响训练过程中不同噪声水平的采样频率。论文中没有明确提及损失函数的具体形式,但通常使用负对数似然(NLL)作为MDM的损失函数。
🖼️ 关键图片
📊 实验亮点
实验结果表明,使用钟形时间采样策略训练的MDMs在One Billion Word Benchmark (LM1B)上达到相同的验证负对数似然(NLL)的速度比标准训练快约4倍。此外,该方法还在生成困惑度、零样本困惑度和下游任务性能方面取得了显著的提升。
🎯 应用场景
该研究成果可应用于各种自然语言处理任务,如文本生成、机器翻译、文本摘要等。通过加速Masked Diffusion模型的训练,可以降低训练成本,并使其能够应用于更大规模的数据集和模型,从而提升模型性能。此外,该方法还可以推广到其他类型的扩散模型,如图像生成等。
📄 摘要(原文)
Masked diffusion models (MDMs) have emerged as a promising alternative to autoregressive models (ARMs) for language modeling. However, MDMs are known to learn substantially more slowly than ARMs, which may become problematic when scaling MDMs to larger models. Therefore, we ask the following question: how can we accelerate standard MDM training while maintaining its final performance? To this end, we first provide a detailed analysis of why MDM training is slow. We find that the main factor is the locality bias of language: the predictive information for a token is concentrated in nearby positions. We further investigate how this bias slows learning and suggest a simple yet effective remedy: bell-shaped time sampling as a training strategy. Notably, MDMs trained with our training recipe reach the same validation negative log-likelihood (NLL) up to $\sim4\times$ faster than standard training on One Billion Word Benchmark (LM1B). We also show faster improvements in generative perplexity, zero-shot perplexity, and downstream task performance on various benchmarks.