dParallel: Learnable Parallel Decoding for dLLMs

📄 arXiv: 2509.26488v1 📥 PDF

作者: Zigeng Chen, Gongfan Fang, Xinyin Ma, Ruonan Yu, Xinchao Wang

分类: cs.CL

发布日期: 2025-09-30

备注: Working in progress, code base: https://github.com/czg1225/dParallel

🔗 代码/项目: GITHUB


💡 一句话要点

dParallel:面向扩散大语言模型的可学习并行解码方法,显著降低推理延迟。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 扩散模型 大语言模型 并行解码 知识蒸馏 确定性学习

📋 核心要点

  1. 现有扩散大语言模型(dLLMs)虽然具备并行解码的潜力,但实际应用中仍需大量解码步骤,限制了推理速度。
  2. dParallel通过确定性强制蒸馏,使模型更快地对masked tokens产生高置信度预测,从而减少并行解码所需的步骤。
  3. 实验表明,dParallel在GSM8K和MBPP等基准测试中,显著减少了解码步骤,实现了8.5-10.5倍的加速,同时保持了模型性能。

📝 摘要(中文)

扩散大语言模型(dLLMs)作为自回归生成的一种有前景的替代方案,因其并行token预测和较低的推理延迟而备受研究界关注。然而,它们的并行解码潜力在很大程度上仍未被充分探索,因为现有的开源模型仍然需要接近token长度的解码步骤才能确保性能。为了解决这个问题,我们提出了dParallel,这是一种简单有效的方法,可以释放dLLMs的固有并行性以实现快速采样。我们发现并行解码的关键瓶颈来自于masked tokens的顺序确定性收敛。基于这一洞察,我们引入了我们方法的核心:确定性强制蒸馏,这是一种新颖的训练策略,它蒸馏模型以遵循其原始采样轨迹,同时强制它更快、更并行地在masked tokens上实现高确定性。在各种基准上的大量实验表明,我们的方法可以显著减少解码步骤的数量,同时保持性能。当应用于LLaDA-8B-Instruct模型时,dParallel将GSM8K上的解码步骤从256减少到30,实现了8.5倍的加速,而没有性能下降。在MBPP基准上,它将解码步骤从256减少到24,从而在保持准确性的同时实现了10.5倍的加速。我们的代码可在https://github.com/czg1225/dParallel 获取。

🔬 方法详解

问题定义:论文旨在解决扩散大语言模型(dLLMs)并行解码效率低下的问题。现有dLLMs虽然理论上支持并行生成token,但实际推理时仍需要大量的串行解码步骤,严重限制了其推理速度。其主要痛点在于masked tokens的确定性收敛速度慢,导致需要多次迭代才能得到高质量的生成结果。

核心思路:论文的核心思路是通过“确定性强制蒸馏”来加速masked tokens的确定性收敛。具体而言,通过训练模型,使其在更少的解码步骤内,对masked tokens产生高置信度的预测。这样,模型就能更快地完成并行解码,从而降低推理延迟。

技术框架:dParallel的核心是确定性强制蒸馏训练策略。该策略包括两个关键部分:一是让学生模型(经过dParallel训练的模型)模仿教师模型(原始dLLM)的采样轨迹,保证生成质量;二是引入确定性损失,强制学生模型更快地对masked tokens产生高置信度预测。整体训练流程包括:1. 使用教师模型生成采样轨迹;2. 使用确定性损失训练学生模型,使其逼近教师模型的轨迹,并加速确定性收敛。

关键创新:dParallel的关键创新在于提出了“确定性强制蒸馏”这一训练策略。与传统的蒸馏方法不同,dParallel不仅关注生成结果的相似性,更关注中间过程的确定性收敛速度。通过显式地强制模型更快地产生高置信度预测,从而实现了更高效的并行解码。

关键设计:确定性损失是dParallel的关键设计。具体而言,该损失函数衡量了学生模型在每个解码步骤中,对masked tokens预测概率分布的熵。通过最小化该熵,可以促使模型更快地产生尖锐的概率分布,从而实现高置信度预测。此外,论文还探索了不同的蒸馏策略,例如直接模仿教师模型的logits或概率分布,以进一步提升生成质量。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

dParallel在LLaDA-8B-Instruct模型上进行了实验,结果表明,在GSM8K基准测试中,解码步骤从256减少到30,实现了8.5倍的加速,且性能没有下降。在MBPP基准测试中,解码步骤从256减少到24,实现了10.5倍的加速,同时保持了准确性。这些结果表明dParallel能够有效提升dLLMs的并行解码效率。

🎯 应用场景

dParallel的潜在应用领域包括需要快速文本生成的场景,例如实时对话系统、机器翻译、代码生成等。通过显著降低dLLMs的推理延迟,dParallel可以提升用户体验,并降低计算成本。未来,该方法有望推广到其他类型的生成模型,进一步提升生成效率。

📄 摘要(原文)

Diffusion large language models (dLLMs) have recently drawn considerable attention within the research community as a promising alternative to autoregressive generation, offering parallel token prediction and lower inference latency. Yet, their parallel decoding potential remains largely underexplored, as existing open-source models still require nearly token-length decoding steps to ensure performance. To address this, we introduce dParallel, a simple and effective method that unlocks the inherent parallelism of dLLMs for fast sampling. We identify that the key bottleneck to parallel decoding arises from the sequential certainty convergence for masked tokens. Building on this insight, we introduce the core of our approach: certainty-forcing distillation, a novel training strategy that distills the model to follow its original sampling trajectories while enforcing it to achieve high certainty on masked tokens more rapidly and in parallel. Extensive experiments across various benchmarks demonstrate that our method can dramatically reduce the number of decoding steps while maintaining performance. When applied to the LLaDA-8B-Instruct model, dParallel reduces decoding steps from 256 to 30 on GSM8K, achieving an 8.5x speedup without performance degradation. On the MBPP benchmark, it cuts decoding steps from 256 to 24, resulting in a 10.5x speedup while maintaining accuracy. Our code is available at https://github.com/czg1225/dParallel