RASD: Retrieval-Augmented Speculative Decoding

📄 arXiv: 2503.03434v1 📥 PDF

作者: Guofeng Quan, Wenfeng Feng, Chuzhan Hao, Guochao Jiang, Yuewei Zhang, Hao Wang

分类: cs.CL, cs.AI

发布日期: 2025-03-05


💡 一句话要点

提出RASD:检索增强的推测解码加速LLM推理,提升领域外泛化性。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 推测解码 检索增强 大型语言模型 推理加速 树剪枝

📋 核心要点

  1. 现有推测解码方法在领域外场景表现不佳,且草稿阶段耗时限制了加速效果。
  2. RASD通过检索增强草稿生成,利用树剪枝和树融合构建更优的验证树。
  3. 实验表明RASD在多种任务上实现了SOTA推理加速,并具有良好的可扩展性。

📝 摘要(中文)

本文提出了一种检索增强的推测解码方法RASD,旨在加速大型语言模型(LLMs)的推理过程。现有的推测解码方法依赖于轻量级的草稿模型或额外的模型结构来生成草稿token并从数据库中检索上下文。然而,由于草稿模型规模小且训练数据有限,基于模型的推测解码在领域外场景中效果不佳。此外,草稿阶段的时间成本限制了验证步骤中接受长度的上限,从而限制了整体效率。RASD采用检索方法来增强基于模型的推测解码,并引入了树剪枝和树融合技术。具体而言,我们开发了一种基于草稿模型概率分布的剪枝方法来构建最优检索树。其次,我们采用最长前缀匹配算法将草稿模型生成的树与检索树合并,从而形成用于验证的统一树。实验结果表明,RASD在DocQA、摘要、代码和领域内QA等任务中实现了最先进的推理加速效果。此外,RASD具有很强的可扩展性,可以无缝地与各种推测解码方法集成,包括基于生成和基于检索的方法。

🔬 方法详解

问题定义:现有推测解码方法,如基于轻量级模型或额外结构的草稿生成,在领域外数据上表现不佳。同时,草稿生成阶段的时间开销限制了后续验证阶段可接受的token数量,最终影响整体的推理加速效果。因此,如何提升草稿生成质量,同时降低其时间开销,是本文要解决的核心问题。

核心思路:RASD的核心思路是利用检索增强来提升草稿生成的质量。通过从外部知识库中检索相关信息,可以弥补草稿模型自身知识的不足,从而生成更准确、更长的草稿序列。同时,通过优化检索过程,可以降低检索带来的额外时间开销。

技术框架:RASD的整体框架包含以下几个主要阶段:1) 使用轻量级草稿模型生成初始草稿树;2) 基于草稿模型的概率分布,对初始草稿树进行剪枝,构建最优检索树;3) 从外部知识库中检索与当前上下文相关的候选token,构建检索树;4) 使用最长前缀匹配算法,将草稿模型生成的树与检索树进行融合,生成最终的验证树;5) 使用目标模型对验证树进行验证,接受或拒绝草稿token。

关键创新:RASD的关键创新在于:1) 提出了一种基于草稿模型概率分布的树剪枝方法,用于构建最优检索树,降低检索范围,提升检索效率;2) 提出了一种基于最长前缀匹配的树融合算法,将草稿模型生成的树与检索树进行有效融合,充分利用了两种信息的优势。

关键设计:在树剪枝阶段,使用草稿模型预测的token概率作为剪枝的依据,只保留概率较高的分支,从而减少检索的候选token数量。在树融合阶段,使用最长前缀匹配算法,优先保留草稿模型生成的token序列,只有在草稿模型无法生成的情况下,才使用检索得到的token序列。具体参数设置未知。

🖼️ 关键图片

fig_0
fig_1

📊 实验亮点

实验结果表明,RASD在DocQA、Summary、Code和In-Domain QA等任务上取得了SOTA的推理加速效果。具体性能数据未知,但论文强调RASD可以无缝集成到现有的推测解码方法中,并具有很强的可扩展性。RASD在领域外数据上的表现优于传统的推测解码方法,证明了其检索增强策略的有效性。

🎯 应用场景

RASD可应用于各种需要加速LLM推理的场景,例如智能客服、机器翻译、代码生成、文档问答等。通过提升推理速度,可以降低计算成本,提高用户体验。该方法尤其适用于领域外数据,可以提升LLM在实际应用中的泛化能力。未来,RASD可以进一步扩展到更多模态的数据,例如图像、音频等。

📄 摘要(原文)

Speculative decoding accelerates inference in large language models (LLMs) by generating draft tokens for target model verification. Current approaches for obtaining draft tokens rely on lightweight draft models or additional model structures to generate draft tokens and retrieve context from databases. Due to the draft model's small size and limited training data, model-based speculative decoding frequently becomes less effective in out-of-domain scenarios. Additionally, the time cost of the drafting phase results in a low upper limit on acceptance length during the verification step, limiting overall efficiency. This paper proposes RASD (Retrieval-Augmented Speculative Decoding), which adopts retrieval methods to enhance model-based speculative decoding. We introduce tree pruning and tree fusion to achieve this. Specifically, we develop a pruning method based on the draft model's probability distribution to construct the optimal retrieval tree. Second, we employ the longest prefix matching algorithm to merge the tree generated by the draft model with the retrieval tree, resulting in a unified tree for verification. Experimental results demonstrate that RASD achieves state-of-the-art inference acceleration across tasks such as DocQA, Summary, Code, and In-Domain QA. Moreover, RASD exhibits strong scalability, seamlessly integrating with various speculative decoding approaches, including both generation-based and retrieval-based methods.