Scaling up Masked Diffusion Models on Text
作者: Shen Nie, Fengqi Zhu, Chao Du, Tianyu Pang, Qian Liu, Guangtao Zeng, Min Lin, Chongxuan Li
分类: cs.AI, cs.CL, cs.LG
发布日期: 2024-10-24 (更新: 2025-02-28)
🔗 代码/项目: GITHUB
💡 一句话要点
提出可扩展的Masked Diffusion模型,在文本生成和理解任务上达到媲美自回归模型的效果。
🎯 匹配领域: 支柱四:生成式动作 (Generative Motion)
关键词: Masked Diffusion模型 文本生成 语言理解 扩展定律 无监督学习
📋 核心要点
- 现有Masked Diffusion模型在语言建模中潜力巨大,但在文本生成和语言理解等核心任务中的可扩展性和有效性有待探索。
- 论文提出了一种可扩展的Masked Diffusion模型,并利用无监督无分类器指导方法,有效利用大规模非配对数据,提升条件推理性能。
- 实验结果表明,该模型在语言理解和文本生成任务上均表现出色,甚至超越了更大规模的自回归模型,并打破了反向诅咒现象。
📝 摘要(中文)
本文针对Masked Diffusion模型(MDM)在语言建模中的潜力,探索了其在文本生成和语言理解等核心任务中的可扩展性和有效性。研究建立了MDM的首个扩展定律,表明其扩展速率与自回归模型(ARM)相当,且计算差距相对较小。基于此,训练了一系列参数高达11亿的MDM,系统评估了它们与同等或更大规模ARM的性能。充分利用MDM的概率公式,提出了一种简单而有效的无监督无分类器指导方法,有效利用大规模非配对数据,提升了条件推理的性能。在语言理解方面,11亿参数的MDM在八个零样本基准测试中的四个上优于在相同数据上训练的11亿参数TinyLlama模型。值得注意的是,它在GSM8K数据集上实现了与70亿参数Llama-2模型相当的数学推理能力。在文本生成方面,MDM通过16倍的预训练时间,在性能上与使用KV-Cache加速采样的ARM相匹配,同时采样速度快1.4倍,实现了灵活的权衡。此外,MDM通过有效地处理双向推理和适应数据中的时间变化,解决了ARM面临的挑战性任务。值得注意的是,一个11亿参数的MDM打破了更大的ARM(如130亿参数Llama-2和1750亿参数GPT-3)在更多数据和计算下遇到的反向诅咒。
🔬 方法详解
问题定义:现有自回归模型(ARM)在文本生成和语言理解任务中占据主导地位,但Masked Diffusion模型(MDM)作为一种新兴的生成模型,其潜力尚未被充分挖掘。MDM的可扩展性、在核心语言任务中的有效性以及与ARM的性能差距是需要解决的关键问题。此外,如何有效利用MDM的概率特性来提升其性能也是一个挑战。
核心思路:论文的核心思路是探索MDM的扩展规律,并训练大规模的MDM模型,以系统地评估其在文本生成和语言理解任务中的性能。同时,利用MDM的概率公式,设计一种无监督无分类器指导方法,以提升条件推理的性能。通过这种方式,充分发挥MDM的优势,使其在特定任务上能够与甚至超越ARM。
技术框架:整体框架包括以下几个主要步骤:1) 建立MDM的扩展定律,分析其与ARM的扩展速率和计算差距;2) 训练一系列不同规模的MDM模型,最大规模达到11亿参数;3) 提出无监督无分类器指导方法,利用大规模非配对数据提升条件推理性能;4) 在多个语言理解和文本生成基准测试上评估MDM的性能,并与ARM进行对比;5) 分析MDM在处理双向推理和适应数据时间变化方面的优势。
关键创新:论文最重要的技术创新点在于:1) 建立了MDM的首个扩展定律,为MDM的规模化训练提供了理论指导;2) 提出了一种简单而有效的无监督无分类器指导方法,充分利用了MDM的概率特性,提升了条件推理的性能;3) 证明了MDM在特定任务上可以超越更大规模的ARM,并打破了反向诅咒现象。
关键设计:论文的关键设计包括:1) 模型架构的选择:采用Transformer架构作为MDM的基础模型;2) 训练数据的选择:使用大规模的文本数据集进行预训练;3) 损失函数的设计:采用标准的扩散模型损失函数;4) 无监督无分类器指导方法的具体实现:通过调整采样过程中的噪声水平来实现指导;5) 实验参数的设置:根据模型规模和任务特点,调整学习率、batch size等超参数。
🖼️ 关键图片
📊 实验亮点
实验结果表明,11亿参数的MDM在语言理解任务中,在八个零样本基准测试中的四个上优于11亿参数的TinyLlama模型。在GSM8K数据集上,MDM实现了与70亿参数Llama-2模型相当的数学推理能力。在文本生成方面,MDM在性能上与使用KV-Cache加速采样的ARM相匹配,同时采样速度快1.4倍。此外,11亿参数的MDM打破了130亿参数Llama-2和1750亿参数GPT-3等更大规模ARM遇到的反向诅咒。
🎯 应用场景
该研究成果可应用于各种自然语言处理任务,如文本生成、机器翻译、文本摘要、问答系统等。通过利用MDM的优势,可以提升这些任务的性能,并解决ARM面临的一些挑战,例如双向推理和适应数据时间变化。此外,该研究还为MDM的未来发展提供了新的方向,例如探索更有效的训练方法和模型架构。
📄 摘要(原文)
Masked diffusion models (MDMs) have shown promise in language modeling, yet their scalability and effectiveness in core language tasks, such as text generation and language understanding, remain underexplored. This paper establishes the first scaling law for MDMs, demonstrating a scaling rate comparable to autoregressive models (ARMs) and a relatively small compute gap. Motivated by their scalability, we train a family of MDMs with up to 1.1 billion (B) parameters to systematically evaluate their performance against ARMs of comparable or larger sizes. Fully leveraging the probabilistic formulation of MDMs, we propose a simple yet effective unsupervised classifier-free guidance that effectively exploits large-scale unpaired data, boosting performance for conditional inference. In language understanding, the 1.1B MDM outperforms the 1.1B TinyLlama model trained on the same data across four of eight zero-shot benchmarks. Notably, it achieves competitive math reasoning ability with the 7B Llama-2 model on the GSM8K dataset. In text generation, MDMs with 16 times more pre-training time offer a flexible trade-off against ARMs with the accelerated sampling technique KV-Cache: MDMs match ARMs in performance while being 1.4 times faster during sampling. Moreover, MDMs address challenging tasks for ARMs by effectively handling bidirectional reasoning and adapting to temporal shifts in data. Notably, a 1.1B MDM breaks the reverse curse encountered by much larger ARMs with significantly more data and computation, such as 13B Llama-2 and 175B GPT-3. Our code is available at https://github.com/ML-GSAI/SMDM.