Optimized Multi-Token Joint Decoding with Auxiliary Model for LLM Inference

📄 arXiv: 2407.09722v4 📥 PDF

作者: Zongyue Qin, Ziniu Hu, Zifan He, Neha Prakriya, Jason Cong, Yizhou Sun

分类: cs.CL, cs.LG

发布日期: 2024-07-12 (更新: 2025-04-10)

期刊: ICLR 2025


💡 一句话要点

提出MTAD框架,加速LLM多token联合解码,提升推理速度和效果。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 多token联合解码 大型语言模型 推理加速 辅助模型 推测解码 低功耗 模型优化

📋 核心要点

  1. 现有LLM推理效率低,单token生成方式耗时耗力,推测解码虽加速但未提升效果。
  2. MTAD框架利用辅助模型近似联合分布,加速多token联合解码,并通过验证机制保证准确性。
  3. 实验表明,MTAD降低困惑度,提升下游任务性能,加速推理并降低能耗,优于传统推测解码。

📝 摘要(中文)

大型语言模型(LLMs)在各种任务中取得了显著成功,但由于每个解码步骤的单token生成,其推理过程受到大量时间和能源需求的阻碍。先前的推测解码方法通过每步生成多个token来缓解这些低效率,但每个token仍然由其单token分布生成,从而提高了速度但没有提高效果。相反,我们的工作同时提高了推理速度并提高了输出效果。我们考虑多token联合解码(MTJD),它在每次迭代中从多个token的联合分布生成多个token,理论上降低了困惑度并提高了任务性能。然而,MTJD受到从多个token的联合分布中采样的高成本的影响。受推测解码的启发,我们引入了多token辅助解码(MTAD),这是一种旨在加速MTJD的新颖框架。MTAD利用较小的辅助模型来近似较大模型的联合分布,结合验证机制,不仅确保了这种近似的准确性,而且提高了优于传统推测解码的解码效率。从理论上讲,我们证明了MTAD以有界误差逼近精确的MTJD。使用Llama-2和OPT模型(参数范围从13B到70B)在各种任务中进行的实证评估表明,与标准单token采样相比,MTAD降低了21.2%的困惑度并提高了下游性能。此外,MTAD实现了1.42倍的加速,并且比传统的推测解码方法消耗的能量少1.54倍。这些结果突出了MTAD使多token联合解码既有效又高效的能力,从而促进了LLM的更可持续和高性能部署。

🔬 方法详解

问题定义:现有大型语言模型推理过程中,单token生成方式效率低下,导致时间和能源消耗巨大。推测解码等方法虽然可以加速推理,但仍然基于单token分布生成,无法从根本上提升生成质量。多token联合解码(MTJD)理论上可以降低困惑度并提升性能,但直接从多token联合分布中采样计算成本过高。

核心思路:论文的核心思路是利用一个较小的辅助模型来近似大型模型的联合分布,从而加速多token联合解码过程。类似于推测解码,但不同之处在于,MTAD是近似多token的联合分布,而不是单个token的分布。通过引入验证机制,确保辅助模型近似的准确性,避免引入过多的误差。

技术框架:MTAD框架包含以下主要模块:1)辅助模型:用于近似大型模型的联合分布,快速生成多个候选token。2)验证机制:验证辅助模型生成的token序列是否被大型模型接受。3)主模型:大型语言模型,用于验证和生成最终的token序列。整体流程是,辅助模型生成多个token,然后由主模型验证,如果验证通过,则直接采用这些token,否则,主模型生成新的token。

关键创新:MTAD的关键创新在于使用辅助模型近似多token的联合分布,并结合验证机制,实现了高效且准确的多token联合解码。与传统的推测解码相比,MTAD不仅加速了推理过程,还提升了生成质量。此外,理论分析证明了MTAD能够以有界误差逼近精确的MTJD。

关键设计:辅助模型的选择需要权衡计算成本和近似精度。验证机制的设计需要考虑验证的效率和准确性。论文中可能涉及到一些关键的参数设置,例如辅助模型的大小、验证的阈值等,这些参数会影响MTAD的性能。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,MTAD在Llama-2和OPT模型上,参数范围从13B到70B,与标准单token采样相比,困惑度降低了21.2%,下游任务性能得到提升。同时,MTAD实现了1.42倍的加速,并且比传统的推测解码方法消耗的能量少1.54倍。这些数据表明MTAD在提升效率和降低能耗方面具有显著优势。

🎯 应用场景

MTAD框架可应用于各种需要高性能和低能耗的大型语言模型部署场景,例如云端推理服务、边缘设备部署等。通过提升推理速度和效果,MTAD可以降低LLM的使用成本,并促进其在更多领域的应用,例如智能客服、内容生成、机器翻译等。

📄 摘要(原文)

Large language models (LLMs) have achieved remarkable success across diverse tasks, yet their inference processes are hindered by substantial time and energy demands due to single-token generation at each decoding step. While previous methods such as speculative decoding mitigate these inefficiencies by producing multiple tokens per step, each token is still generated by its single-token distribution, thereby enhancing speed without improving effectiveness. In contrast, our work simultaneously enhances inference speed and improves the output effectiveness. We consider multi-token joint decoding (MTJD), which generates multiple tokens from their joint distribution at each iteration, theoretically reducing perplexity and enhancing task performance. However, MTJD suffers from the high cost of sampling from the joint distribution of multiple tokens. Inspired by speculative decoding, we introduce multi-token assisted decoding (MTAD), a novel framework designed to accelerate MTJD. MTAD leverages a smaller auxiliary model to approximate the joint distribution of a larger model, incorporating a verification mechanism that not only ensures the accuracy of this approximation, but also improves the decoding efficiency over conventional speculative decoding. Theoretically, we demonstrate that MTAD closely approximates exact MTJD with bounded error. Empirical evaluations using Llama-2 and OPT models ranging from 13B to 70B parameters across various tasks reveal that MTAD reduces perplexity by 21.2% and improves downstream performance compared to standard single-token sampling. Furthermore, MTAD achieves a 1.42x speed-up and consumes 1.54x less energy than conventional speculative decoding methods. These results highlight MTAD's ability to make multi-token joint decoding both effective and efficient, promoting more sustainable and high-performance deployment of LLMs.