Learning Harmonized Representations for Speculative Sampling

📄 arXiv: 2408.15766v3 📥 PDF

作者: Lefan Zhang, Xiaodan Wang, Yanhua Huang, Ruiwen Xu

分类: cs.LG, cs.CL

发布日期: 2024-08-28 (更新: 2025-02-26)

备注: Published as a conference paper at ICLR 2025

🔗 代码/项目: GITHUB


💡 一句话要点

提出HArmonized Speculative Sampling (HASS)以解决LLM推断加速中的上下文不一致问题

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 推测采样 大语言模型 模型加速 上下文对齐 目标蒸馏

📋 核心要点

  1. 现有推测采样方法在训练和解码阶段存在上下文不一致的问题,导致性能瓶颈。
  2. HASS通过学习协调的表示,利用协调的目标蒸馏和上下文对齐来解决上下文不一致问题。
  3. 实验表明,HASS在LLaMA模型上实现了显著的加速,优于现有方法EAGLE-2。

📝 摘要(中文)

推测采样是一种加速大型语言模型(LLM)解码阶段的有前景的方法。最近利用目标LLM的上下文信息(如隐藏状态和KV缓存)的进展显示出显著的实际改进。然而,这些方法存在训练和解码之间上下文不一致的问题。我们还观察到现有推测采样方法中训练和解码目标之间的另一个差异。在这项工作中,我们提出了一种名为HArmonized Speculative Sampling (HASS) 的解决方案,该方案学习协调的表示以解决这些问题。HASS通过协调的目标蒸馏和协调的上下文对齐来加速解码阶段,而无需增加推理开销。在四个LLaMA模型上的实验表明,HASS在三个数据集上的平均加速比为2.81x-4.05x,超过EAGLE-2 8%-20%。代码可在https://github.com/HArmonizedSS/HASS获取。

🔬 方法详解

问题定义:现有推测采样方法在加速LLM解码时,面临训练和解码阶段上下文不一致的问题。具体来说,训练阶段使用的上下文信息与解码阶段实际使用的上下文信息存在差异,导致模型性能下降。此外,训练目标和解码目标也存在差异,进一步加剧了这一问题。

核心思路:HASS的核心思路是学习协调的表示,从而弥合训练和解码阶段的上下文差异。通过协调的目标蒸馏,使小模型学习与大模型一致的目标;通过协调的上下文对齐,使小模型学习与大模型相似的上下文表示。

技术框架:HASS的整体框架包括两个主要模块:协调的目标蒸馏和协调的上下文对齐。协调的目标蒸馏模块利用大模型的输出作为监督信号,训练小模型学习与大模型一致的目标。协调的上下文对齐模块通过最小化小模型和大模型上下文表示之间的距离,使小模型学习与大模型相似的上下文表示。

关键创新:HASS的关键创新在于提出了协调的表示学习方法,通过协调的目标蒸馏和上下文对齐,有效地解决了训练和解码阶段的上下文不一致问题。与现有方法相比,HASS能够更好地利用目标LLM的上下文信息,从而实现更高的加速比。

关键设计:在目标蒸馏方面,HASS使用KL散度作为损失函数,衡量小模型和大模型输出之间的差异。在上下文对齐方面,HASS使用余弦相似度作为距离度量,衡量小模型和大模型上下文表示之间的相似度。此外,HASS还采用了动态调整的权重系数,平衡目标蒸馏和上下文对齐之间的重要性。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,HASS在四个LLaMA模型上实现了显著的加速,平均加速比为2.81x-4.05x,超过了现有方法EAGLE-2 8%-20%。这些结果表明,HASS能够有效地解决训练和解码阶段的上下文不一致问题,从而实现更高的加速比。此外,HASS无需增加推理开销,使其更具实用性。

🎯 应用场景

HASS可应用于各种需要加速LLM解码的场景,例如机器翻译、文本生成、对话系统等。通过提高LLM的推理速度,HASS可以降低计算成本,提高用户体验,并促进LLM在更多实际应用中的部署。该方法尤其适用于资源受限的设备,例如移动设备和嵌入式系统。

📄 摘要(原文)

Speculative sampling is a promising approach to accelerate the decoding stage for Large Language Models (LLMs). Recent advancements that leverage target LLM's contextual information, such as hidden states and KV cache, have shown significant practical improvements. However, these approaches suffer from inconsistent context between training and decoding. We also observe another discrepancy between the training and decoding objectives in existing speculative sampling methods. In this work, we propose a solution named HArmonized Speculative Sampling (HASS) that learns harmonized representations to address these issues. HASS accelerates the decoding stage without adding inference overhead through harmonized objective distillation and harmonized context alignment. Experiments on four LLaMA models demonstrate that HASS achieves 2.81x-4.05x wall-clock time speedup ratio averaging across three datasets, surpassing EAGLE-2 by 8%-20%. The code is available at https://github.com/HArmonizedSS/HASS.