SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths

📄 arXiv: 2405.19715v3 📥 PDF

作者: Kaixuan Huang, Xudong Guo, Mengdi Wang

分类: cs.CL, cs.AI, cs.LG

发布日期: 2024-05-30 (更新: 2025-07-11)

备注: Accepted to COLM 2025

🔗 代码/项目: GITHUB


💡 一句话要点

SpecDec++通过自适应候选长度提升推测解码效率,加速大语言模型推理。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 推测解码 大语言模型 模型加速 自适应算法 马尔可夫决策过程

📋 核心要点

  1. 现有推测解码方法在选择候选长度K时依赖简单启发式,导致性能并非最优。
  2. SpecDec++将候选长度选择建模为马尔可夫决策过程,并证明最优策略是阈值策略。
  3. SpecDec++通过自适应调整候选长度,在多个数据集上显著提升了推理速度。

📝 摘要(中文)

推测解码通过利用一个更小更快的草稿模型来减少目标大语言模型的推理延迟。其性能取决于超参数K——候选长度,即目标模型在每一轮中验证的候选token的数量。然而,先前的方法通常使用简单的启发式方法来选择K,这可能导致次优的性能。本文研究了候选长度K的选择,并将其形式化为一个马尔可夫决策过程。理论上证明,该马尔可夫决策过程的最优策略采用阈值策略的形式,即当获得拒绝的概率超过阈值时,当前的推测应该停止并进行验证。受此理论的启发,本文提出了SpecDec++,它是推测解码的增强版本,可以动态地自适应地确定候选长度。SpecDec++使用训练好的接受预测头来增强草稿模型,以预测候选token的条件接受概率。当预测到至少一个token被拒绝的概率超过阈值时,SpecDec++将停止当前的推测。SpecDec++应用于llama-2-chat 7B和70B模型对。自适应方法在Alpaca数据集上实现了2.04倍的加速(比基线推测解码提高了7.2%)。在GSM8K和HumanEval数据集上,该方法分别实现了2.26倍的加速(提高了9.4%)和2.23倍的加速(提高了11.1%)。

🔬 方法详解

问题定义:推测解码旨在加速大型语言模型的推理过程,但现有方法中候选长度K的选择通常采用固定的启发式策略,无法根据实际情况动态调整,导致次优的加速效果。如何根据当前生成状态自适应地选择最优的候选长度,是本文要解决的核心问题。

核心思路:论文将候选长度的选择建模为一个马尔可夫决策过程(MDP),并从理论上证明了该MDP的最优策略具有阈值形式。这意味着存在一个阈值,当预测到当前候选序列被拒绝的概率超过该阈值时,就应该停止推测并进行验证。基于此,论文提出了一种自适应调整候选长度的推测解码方法SpecDec++。

技术框架:SpecDec++的核心框架包括以下几个部分:1) 草稿模型:用于生成候选token序列;2) 接受预测头:一个训练好的模型,用于预测候选token被目标模型接受的条件概率;3) 阈值策略:根据接受预测头输出的概率,动态调整候选长度。具体流程是,草稿模型生成候选token序列,接受预测头预测每个token被接受的概率,然后根据阈值策略决定是否停止推测并进行验证。

关键创新:SpecDec++的关键创新在于:1) 将候选长度选择建模为MDP,并从理论上证明了最优策略的阈值形式;2) 提出了一个自适应调整候选长度的推测解码方法,能够根据当前生成状态动态调整候选长度,从而提高加速效果。与现有方法相比,SpecDec++不再依赖固定的候选长度,而是能够根据实际情况进行调整,从而更有效地利用草稿模型和目标模型。

关键设计:SpecDec++的关键设计包括:1) 接受预测头的训练:接受预测头需要准确预测候选token被目标模型接受的概率,因此需要使用大量数据进行训练。论文中使用了目标模型的输出作为训练数据,通过监督学习的方式训练接受预测头。2) 阈值的选择:阈值的选择会影响SpecDec++的性能,阈值过高会导致推测次数减少,加速效果不明显;阈值过低会导致推测错误率增加,影响生成质量。论文中通过实验确定了一个合适的阈值。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

SpecDec++在多个数据集上取得了显著的加速效果。在Alpaca数据集上,SpecDec++实现了2.04倍的加速,相比基线推测解码提高了7.2%。在GSM8K和HumanEval数据集上,SpecDec++分别实现了2.26倍(提高9.4%)和2.23倍(提高11.1%)的加速。这些结果表明,SpecDec++能够有效地提高推测解码的效率。

🎯 应用场景

SpecDec++可广泛应用于需要加速大型语言模型推理的场景,例如在线对话系统、文本生成、机器翻译等。通过自适应调整候选长度,SpecDec++能够更有效地利用计算资源,降低推理延迟,提升用户体验。该研究对于推动大语言模型在实际应用中的部署具有重要意义。

📄 摘要(原文)

Speculative decoding reduces the inference latency of a target large language model via utilizing a smaller and faster draft model. Its performance depends on a hyperparameter K -- the candidate length, i.e., the number of candidate tokens for the target model to verify in each round. However, previous methods often use simple heuristics to choose K, which may result in sub-optimal performance. We study the choice of the candidate length K and formulate it as a Markov Decision Process. We theoretically show that the optimal policy of this Markov decision process takes the form of a threshold policy, i.e., the current speculation should stop and be verified when the probability of getting a rejection exceeds a threshold value. Motivated by this theory, we propose SpecDec++, an enhanced version of speculative decoding that adaptively determines the candidate length on the fly. We augment the draft model with a trained acceptance prediction head to predict the conditional acceptance probability of the candidate tokens. SpecDec++ will stop the current speculation when the predicted probability that at least one token gets rejected exceeds a threshold. We implement SpecDec++ and apply it to the llama-2-chat 7B & 70B model pair. Our adaptive method achieves a 2.04x speedup on the Alpaca dataset (7.2% improvement over the baseline speculative decoding). On the GSM8K and HumanEval datasets, our method achieves a 2.26x speedup (9.4% improvement) and 2.23x speedup (11.1% improvement), respectively. The code of this paper is available at https://github.com/Kaffaljidhmah2/SpecDec_pp.