Gumbel Distillation for Parallel Text Generation
作者: Chi Zhang, Xixi Hu, Bo Liu, Qiang Liu
分类: cs.CL, cs.LG
发布日期: 2026-03-23
备注: ICLR 2026
🔗 代码/项目: GITHUB
💡 一句话要点
提出Gumbel蒸馏,提升并行文本生成模型质量,缩小与自回归模型的差距
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 并行文本生成 知识蒸馏 Gumbel-Max技巧 非自回归模型 语言模型 序列生成 模型加速
📋 核心要点
- 自回归语言模型生成速度慢,并行解码模型虽然速度快,但生成质量不如自回归模型。
- 利用Gumbel-Max技巧,将AR教师模型的知识蒸馏到并行解码器,使其更好地学习token序列的联合分布。
- Gumbel蒸馏是一种模型无关方法,可以与多种并行解码架构集成,并在多个数据集上取得了显著的性能提升。
📝 摘要(中文)
自回归(AR)语言模型因其缓慢的序列生成特性,推动了并行解码方法的发展。然而,这些非自回归模型在建模token序列的复杂联合分布时面临挑战,导致生成质量下降。为了缩小这一性能差距,我们引入了Gumbel蒸馏,一种新颖的蒸馏技术,使并行解码器能够有效地学习这种分布。我们的方法利用Gumbel-Max技巧,从潜在的Gumbel噪声空间创建一个到高性能AR教师模型输出token的确定性映射。作为一种模型无关的技术,Gumbel蒸馏可以无缝集成到各种并行解码架构中,包括MDLM和BD3-LM。在LM1B和OpenWebText上的实验表明,Gumbel蒸馏显著提高了并行语言模型的生成质量,在OpenWebText数据集上训练的MDLM模型,MAUVE得分提高了30.0%,生成困惑度降低了10.5%。代码已开源。
🔬 方法详解
问题定义:并行文本生成模型虽然生成速度快,但由于难以建模复杂的token序列联合分布,导致生成质量不如自回归模型。现有方法在保证生成质量方面存在瓶颈,无法充分发挥并行解码的优势。
核心思路:利用知识蒸馏,将高性能自回归教师模型的知识迁移到并行解码器。具体而言,通过Gumbel-Max技巧,建立从潜在Gumbel噪声空间到教师模型输出token的确定性映射,使得并行解码器能够学习到教师模型的token分布。这样设计的目的是让并行模型模仿自回归模型的生成模式,从而提高生成质量。
技术框架:Gumbel蒸馏框架包含一个自回归教师模型和一个并行解码器学生模型。训练过程中,首先使用Gumbel-Max技巧从Gumbel噪声空间采样,然后通过教师模型得到对应的token序列。接着,并行解码器学习预测这些token序列,从而模仿教师模型的行为。推理阶段,并行解码器直接生成token序列。
关键创新:该方法的核心创新在于利用Gumbel-Max技巧建立Gumbel噪声空间到token序列的确定性映射,从而将自回归教师模型的知识有效地传递给并行解码器。与传统的知识蒸馏方法不同,Gumbel蒸馏不需要对教师模型的输出概率进行软化,而是直接学习token序列的生成过程。
关键设计:Gumbel-Max技巧是关键。具体来说,对于教师模型的每个token,计算其logits,并加上从Gumbel分布中采样的噪声。然后,通过argmax操作选择具有最大值的token。通过调整Gumbel分布的温度参数,可以控制token选择的随机性。损失函数通常采用交叉熵损失,衡量并行解码器的输出与教师模型输出之间的差异。
🖼️ 关键图片
📊 实验亮点
在LM1B和OpenWebText数据集上的实验结果表明,Gumbel蒸馏能够显著提高并行语言模型的生成质量。例如,在OpenWebText数据集上训练的MDLM模型,使用Gumbel蒸馏后,MAUVE得分提高了30.0%,生成困惑度降低了10.5%。这些结果表明,Gumbel蒸馏是一种有效的并行文本生成模型训练方法。
🎯 应用场景
Gumbel蒸馏技术可广泛应用于需要快速文本生成的场景,例如机器翻译、文本摘要、对话系统等。通过提高并行文本生成模型的质量,可以显著提升这些应用的性能和用户体验。该技术还有潜力应用于其他序列生成任务,例如语音合成和音乐生成。
📄 摘要(原文)
The slow, sequential nature of autoregressive (AR) language models has driven the adoption of parallel decoding methods. However, these non-AR models often sacrifice generation quality as they struggle to model the complex joint distribution of token sequences. To narrow this performance gap, we introduce Gumbel Distillation, a novel distillation technique that enables parallel decoders to learn this distribution effectively. Our method leverages the Gumbel-Max trick to create a deterministic mapping from a latent Gumbel noise space to the output tokens of a high-performing AR teacher. As a model-agnostic technique, Gumbel Distillation seamlessly integrates with diverse parallel decoding architectures, including MDLM and BD3-LM. Experiments on LM1B and OpenWebText show that Gumbel Distillation substantially improves the generation quality of parallel language models, achieving a 30.0% improvement in MAUVE score and 10.5% in generative perplexity over MDLM trained on OpenWebText dataset. Code available at https://github.com/hxixixh/gumbel-distill.