Judge Decoding: Faster Speculative Sampling Requires Going Beyond Model Alignment

📄 arXiv: 2501.19309v1 📥 PDF

作者: Gregor Bachmann, Sotiris Anagnostidis, Albert Pumarola, Markos Georgopoulos, Artsiom Sanakoyeu, Yuming Du, Edgar Schönfeld, Ali Thabet, Jonas Kohler

分类: cs.LG, cs.CL

发布日期: 2025-01-31


💡 一句话要点

提出Judge Decoding,通过训练判别模块显著加速LLM推断,突破模型对齐限制。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 推测解码 大语言模型 LLM加速 模型推理 知识蒸馏 模型对齐 判别模型

📋 核心要点

  1. 现有推测解码方法依赖草稿模型与目标模型的对齐,导致大量高质量token被拒绝,限制了加速潜力。
  2. 论文提出Judge Decoding,训练一个判别模块来判断延续的正确性,即使草稿模型与目标模型未完全对齐。
  3. 实验表明,Judge Decoding在Llama-3.1系列上实现了显著加速,最高可达9倍,且保持了模型质量。

📝 摘要(中文)

大型语言模型(LLM)的性能与其模型规模密切相关,导致网络越来越大,推理速度也越来越慢。推测解码是一种加速自回归生成的技术,它利用快速的草稿模型来提出候选token,然后基于目标模型的可能性并行验证这些token。虽然这种方法保证了重现目标输出,但它会带来相当大的损失:许多高质量的草稿token被拒绝,即使它们代表了客观上有效的延续。事实上,即使是像GPT-4o这样强大的草稿模型,以及人类文本,也无法在标准验证方案下实现高接受率。这严重限制了当前推测解码方法的速度提升潜力,因为仅仅依靠草稿和目标模型的对齐,早期拒绝的可能性就变得非常大。因此,我们提出了以下问题:我们能否调整验证方法,以识别正确但未对齐的回复?为此,我们从LLM-as-a-judge框架中汲取灵感,该框架表明LLM能够以多种方式对答案进行评分。我们精心设计了一个数据集,通过在嵌入之上训练一个紧凑的模块来产生当前延续的“判断”,从而在目标模型中引发相同的功能。我们在Llama-3.1系列上展示了我们的策略,其中我们的8b/405B-Judge在保持其在大量基准测试中的质量的同时,实现了相对于Llama-405B的9倍加速。即使在优化的推理框架中,这些优势仍然存在,其中我们的方法在2个和8个H100上分别达到了8B/70B-Judge的141个token/s和8B/405B的129个token/s。

🔬 方法详解

问题定义:推测解码旨在加速LLM的自回归生成过程,但现有方法严重依赖草稿模型和目标模型的对齐程度。即使草稿模型生成的token在语义上是合理的,但如果与目标模型的预测不完全一致,也会被拒绝,导致效率降低。现有方法的痛点在于过于严格的对齐要求,限制了加速潜力。

核心思路:论文的核心思路是引入一个“判断”模块,该模块能够评估当前延续的质量,而不仅仅是依赖草稿模型和目标模型的对齐。这个判断模块可以识别那些虽然与目标模型预测不完全一致,但仍然是正确和有用的token。通过这种方式,可以提高token的接受率,从而加速生成过程。

技术框架:整体框架包括一个草稿模型、一个目标模型和一个判断模块。草稿模型快速生成候选token,目标模型用于验证这些token,而判断模块则对当前延续进行评估。判断模块的输出用于调整token的接受策略。具体流程如下:1) 草稿模型生成多个候选token;2) 目标模型计算这些token的概率;3) 判断模块基于当前上下文和候选token生成一个“判断”;4) 基于目标模型的概率和判断模块的输出,决定是否接受这些token。

关键创新:最重要的技术创新点在于引入了判断模块,该模块能够超越简单的模型对齐,识别语义上合理的延续。与现有方法相比,Judge Decoding不再仅仅依赖草稿模型和目标模型的概率分布是否一致,而是更加关注延续的整体质量。

关键设计:论文设计了一个专门的数据集来训练判断模块,该数据集包含各种不同质量的延续,以及对应的“判断”标签。判断模块是一个小型神经网络,它接收目标模型的嵌入作为输入,并输出一个标量值,表示对当前延续的判断。损失函数被设计为鼓励判断模块输出与数据集中的标签一致的值。具体来说,使用了二元交叉熵损失函数,其中正例是高质量的延续,负例是低质量的延续。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,Judge Decoding在Llama-3.1系列上实现了显著的加速。例如,使用8b/405B-Judge配置,相对于Llama-405B,实现了9倍的加速,同时保持了模型在各种基准测试中的质量。在优化的推理框架中,8B/70B-Judge达到了141 tokens/s,而8B/405B达到了129 tokens/s,分别在2个和8个H100 GPU上运行。

🎯 应用场景

Judge Decoding可广泛应用于需要快速LLM推理的场景,如实时对话系统、内容生成、代码补全等。该方法降低了对更大模型的依赖,使得在资源受限的环境中使用高性能LLM成为可能。未来,该技术有望进一步提升LLM在边缘设备上的应用潜力。

📄 摘要(原文)

The performance of large language models (LLMs) is closely linked to their underlying size, leading to ever-growing networks and hence slower inference. Speculative decoding has been proposed as a technique to accelerate autoregressive generation, leveraging a fast draft model to propose candidate tokens, which are then verified in parallel based on their likelihood under the target model. While this approach guarantees to reproduce the target output, it incurs a substantial penalty: many high-quality draft tokens are rejected, even when they represent objectively valid continuations. Indeed, we show that even powerful draft models such as GPT-4o, as well as human text cannot achieve high acceptance rates under the standard verification scheme. This severely limits the speedup potential of current speculative decoding methods, as an early rejection becomes overwhelmingly likely when solely relying on alignment of draft and target. We thus ask the following question: Can we adapt verification to recognize correct, but non-aligned replies? To this end, we draw inspiration from the LLM-as-a-judge framework, which demonstrated that LLMs are able to rate answers in a versatile way. We carefully design a dataset to elicit the same capability in the target model by training a compact module on top of the embeddings to produce ``judgements" of the current continuation. We showcase our strategy on the Llama-3.1 family, where our 8b/405B-Judge achieves a speedup of 9x over Llama-405B, while maintaining its quality on a large range of benchmarks. These benefits remain present even in optimized inference frameworks, where our method reaches up to 141 tokens/s for 8B/70B-Judge and 129 tokens/s for 8B/405B on 2 and 8 H100s respectively.