Beyond Autoregression: Fast LLMs via Self-Distillation Through Time
作者: Justin Deschenaux, Caglar Gulcehre
分类: cs.LG, cs.CL
发布日期: 2024-10-28 (更新: 2025-02-06)
💡 一句话要点
提出基于时序自蒸馏的快速扩散语言模型,显著提升生成速度与文本质量。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 扩散模型 语言模型 自蒸馏 并行生成 快速推理 文本生成 自然语言理解
📋 核心要点
- 自回归LLM逐token生成导致推理延迟高,限制了其在实时性要求高的场景中的应用。
- 提出基于时序自蒸馏的扩散语言模型,实现多token并行生成,加速推理过程。
- 实验表明,该方法在文本质量和速度上均优于自回归模型,尤其是在LAMBADA基准上。
📝 摘要(中文)
自回归大型语言模型(LLMs)在众多任务中取得了显著成功。然而,自回归建模范式存在一些局限性;例如,当前的自回归LLMs被训练为一次生成一个token,这可能导致明显的延迟。最近的进展表明,搜索和重复采样可以通过在推理期间利用更多的计算资源来提高各种应用(如定理证明、代码生成和对齐)的性能。在本研究中,我们证明了扩散语言模型能够同时生成至少32个token,同时在文本质量和LAMBADA自然语言理解基准测试中超过了自回归模型的性能。这一结果是通过一种用于离散扩散模型的新型蒸馏方法实现的,该方法将推理步骤的数量减少了32-64倍。实际上,在13亿参数规模下,即使没有缓存,扩散模型也能以比采用KV缓存的自回归模型快8倍的速度生成token,并且我们预计通过包含缓存可以进一步改进。此外,我们证明了我们的方法对于高达8.6亿参数的扩散语言模型的有效性。
🔬 方法详解
问题定义:自回归语言模型(AR LLMs)的推理速度受限于其固有的自回归特性,即必须逐个token生成文本。这导致在高并发或对延迟敏感的应用中,AR LLMs的效率较低。现有方法,如KV-caching,虽然可以缓解部分问题,但仍然无法从根本上解决顺序生成带来的瓶颈。
核心思路:该论文的核心思路是利用扩散模型的多token并行生成能力,通过一种新颖的自蒸馏方法,将扩散模型的生成速度提升到可以与甚至超过AR模型的速度。通过将扩散模型训练成可以直接预测多个token,从而避免了AR模型逐个token生成的限制。
技术框架:该方法基于离散扩散模型,并引入了时序自蒸馏技术。整体流程包括:1) 训练一个标准的离散扩散模型;2) 使用自蒸馏方法,将扩散模型的推理步数大幅减少(32-64倍);3) 在推理阶段,扩散模型可以并行生成多个token。该框架的关键在于自蒸馏过程,它允许模型在更少的步骤内生成高质量的文本。
关键创新:该论文的关键创新在于提出了一种针对离散扩散模型的自蒸馏方法,使其能够在保证文本质量的前提下,大幅减少推理步骤。这种方法不同于传统的蒸馏方法,它利用了扩散模型自身的特性,通过时序上的信息传递,实现了更高效的知识迁移。
关键设计:自蒸馏过程的关键在于设计合适的损失函数,以确保蒸馏后的模型能够尽可能地逼近原始模型的输出分布。具体的损失函数可能涉及到KL散度或交叉熵等,用于衡量蒸馏模型和原始模型在token预测上的差异。此外,模型的参数设置,如扩散步骤的数量、噪声schedule等,也会影响最终的生成质量和速度。
🖼️ 关键图片
📊 实验亮点
实验结果表明,该方法在13亿参数规模下,即使没有缓存,扩散模型也能以比采用KV缓存的自回归模型快8倍的速度生成token。此外,在LAMBADA自然语言理解基准测试中,该方法的性能超过了自回归模型,证明了其在文本质量方面的优势。该方法在高达8.6亿参数的扩散语言模型上同样有效。
🎯 应用场景
该研究成果可广泛应用于需要快速文本生成的场景,例如实时对话系统、机器翻译、代码生成等。通过提高LLM的生成速度,可以显著改善用户体验,并降低计算成本。未来,该方法有望进一步扩展到更大规模的LLM,并与其他加速技术(如模型压缩、量化)相结合,实现更高效的文本生成。
📄 摘要(原文)
Autoregressive (AR) Large Language Models (LLMs) have demonstrated significant success across numerous tasks. However, the AR modeling paradigm presents certain limitations; for instance, contemporary autoregressive LLMs are trained to generate one token at a time, which can result in noticeable latency. Recent advances have indicated that search and repeated sampling can enhance performance in various applications, such as theorem proving, code generation, and alignment, by utilizing greater computational resources during inference. In this study, we demonstrate that diffusion language models are capable of generating at least 32 tokens simultaneously, while exceeding the performance of AR models in text quality and on the LAMBADA natural language understanding benchmark. This outcome is achieved through a novel distillation method for discrete diffusion models, which reduces the number of inference steps by a factor of 32-64. Practically, at the 1.3B parameters scale, diffusion models, even without caching, can generate tokens at a rate that is up to 8 times faster than AR models employing KV-caching, and we anticipate further improvements with the inclusion of caching. Moreover, we demonstrate the efficacy of our approach for diffusion language models with up to 860M parameters.