Jakiro: Boosting Speculative Decoding with Decoupled Multi-Head via MoE
作者: Haiduo Huang, Fuwei Yang, Zhenhua Liu, Yixing Xu, Jinze Li, Yang Liu, Xuanwu Yin, Dong Li, Pengju Ren, Emad Barsoum
分类: cs.CL, cs.AI, cs.LG
发布日期: 2025-02-10
🔗 代码/项目: GITHUB
💡 一句话要点
Jakiro:利用MoE解耦多头注意力机制,加速推测解码并提升精度。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 推测解码 混合专家模型 MoE 并行解码 对比学习 语言模型加速 模型推理
📋 核心要点
- 现有推测解码方法中,同一解码步骤的候选token由相同表示生成,限制了多样性,影响解码效率。
- Jakiro利用MoE,让独立专家生成多样化预测,有效解耦候选token间的相关性,提升预测准确率。
- 结合自回归与并行解码,并引入对比机制,Jakiro在多种模型上验证了有效性,达到SOTA性能。
📝 摘要(中文)
推测解码(SD)通过使用较小的draft模型预测多个token,然后由较大的target模型并行验证,从而加速大型语言模型的推理。然而,draft模型容量有限,通常需要基于树的采样来提高预测精度,即在每一步生成多个候选。我们发现这种方法的一个关键限制:同一步骤的候选来自相同的表示,限制了多样性并降低了整体效果。为了解决这个问题,我们提出了Jakiro,利用混合专家(MoE),其中独立的专家生成不同的预测,有效地解耦了候选之间的相关性。此外,我们引入了一种混合推理策略,将自回归解码用于初始token,将并行解码用于后续阶段,并通过特征中的对比机制来增强后者,以提高准确性。我们的方法显著提高了预测精度,并实现了更高的推理加速。在各种模型上的大量实验验证了我们方法的有效性和鲁棒性,确立了推测解码领域新的SOTA。我们的代码可在https://github.com/haiduo/Jakiro获得。
🔬 方法详解
问题定义:推测解码旨在加速大型语言模型的推理过程。现有的基于树搜索的推测解码方法,由于在同一解码步骤中,所有候选token都来源于相同的模型表示,导致候选token的多样性不足,限制了推测解码的效率和准确性。
核心思路:Jakiro的核心思路是通过引入混合专家模型(MoE),让不同的专家独立生成候选token,从而增加候选token的多样性,降低候选token之间的相关性。这样,target模型在验证时,能够更容易地找到正确的token,提高推测解码的效率。
技术框架:Jakiro的整体框架包含以下几个主要模块:1) MoE Draft Model: 使用MoE作为draft模型,每个expert独立生成候选token。2) Hybrid Inference Strategy: 采用混合推理策略,初始token使用自回归解码,后续token使用并行解码。3) Contrastive Mechanism: 在并行解码阶段,引入对比机制,通过对比不同候选token的特征表示,进一步提高预测准确性。
关键创新:Jakiro的关键创新在于使用MoE来解耦候选token之间的相关性。传统的推测解码方法使用单一的draft模型,导致候选token高度相关。而Jakiro通过MoE,让不同的专家独立生成候选token,从而增加了候选token的多样性,提高了推测解码的效率。
关键设计:Jakiro的关键设计包括:1) MoE结构: 选择了合适的MoE结构,包括专家数量、路由机制等,以保证MoE的性能。2) 对比损失函数: 设计了合适的对比损失函数,用于训练对比机制,提高预测准确性。3) 混合推理策略: 平衡了自回归解码和并行解码的比例,以达到最佳的推理速度和准确性。
🖼️ 关键图片
📊 实验亮点
实验结果表明,Jakiro在多种模型上均取得了显著的性能提升,相较于现有SOTA方法,在推理速度和准确率上均有提升。具体数据需要在论文中查找。该方法在不同模型和数据集上的鲁棒性也得到了验证。
🎯 应用场景
Jakiro可广泛应用于各种需要加速大型语言模型推理的场景,例如:在线对话系统、机器翻译、文本生成等。通过提高推理速度,可以降低延迟,提升用户体验,并降低部署成本。该方法在资源受限的边缘设备上部署大型语言模型具有重要意义。
📄 摘要(原文)
Speculative decoding (SD) accelerates large language model inference by using a smaller draft model to predict multiple tokens, which are then verified in parallel by the larger target model. However, the limited capacity of the draft model often necessitates tree-based sampling to improve prediction accuracy, where multiple candidates are generated at each step. We identify a key limitation in this approach: the candidates at the same step are derived from the same representation, limiting diversity and reducing overall effectiveness. To address this, we propose Jakiro, leveraging Mixture of Experts (MoE), where independent experts generate diverse predictions, effectively decoupling correlations among candidates. Furthermore, we introduce a hybrid inference strategy, combining autoregressive decoding for initial tokens with parallel decoding for subsequent stages, and enhance the latter with contrastive mechanism in features to improve accuracy. Our method significantly boosts prediction accuracy and achieves higher inference speedups. Extensive experiments across diverse models validate the effectiveness and robustness of our approach, establishing a new SOTA in speculative decoding. Our codes are available at https://github.com/haiduo/Jakiro.