Boosting Lossless Speculative Decoding via Feature Sampling and Partial Alignment Distillation
作者: Lujun Gui, Bin Xiao, Lei Su, Weipeng Chen
分类: cs.CL, cs.LG
发布日期: 2024-08-28
备注: The work was not submitted to AAAI 2025
💡 一句话要点
提出FSPAD,通过特征采样与部分对齐蒸馏提升无损推测解码效率。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 无损推测解码 大语言模型 知识蒸馏 特征采样 部分对齐 模型加速 LLM推理
📋 核心要点
- 现有无损推测解码方法在特征对齐和logit置信度之间存在冲突,限制了性能提升。
- FSPAD通过特征采样获取更准确的目标LLM特征,并采用部分对齐蒸馏解耦特征和logits。
- 实验表明,FSPAD在多种任务和模型上均超越了现有最佳方法,显著提升了推测解码效率。
📝 摘要(中文)
本文提出FSPAD(Feature Sampling and Partial Alignment Distillation for Lossless Speculative Decoding),旨在提升无损推测解码的效率。无损推测解码通过轻量级的草稿模型生成树状候选,再由目标大语言模型并行验证,从而加速目标LLM的推理。现有方法利用特征级别的自回归而非token级别,以简化预测并增强知识蒸馏。FSPAD在现有框架内引入两个简单有效的组件:一是利用token embeddings对目标LLM的高维特征进行采样,解决草稿模型难以获得目标LLM特定token输出的问题;二是引入部分对齐蒸馏,减弱草稿模型中特征与logits之间的联系,以减少训练期间特征对齐与logit置信度之间的冲突。实验结果表明,在Vicuna和LLaMA3-Instruct系列模型上,FSPAD在多轮对话、翻译、摘要、问答、数学推理和检索增强生成等任务中均优于现有最佳方法。
🔬 方法详解
问题定义:论文旨在解决无损推测解码中,草稿模型难以准确预测目标LLM输出的问题。现有方法在特征对齐和logit置信度之间存在冲突,导致草稿模型生成的候选序列质量不高,降低了推测解码的加速效果。
核心思路:论文的核心思路是通过特征采样,使草稿模型能够更好地利用目标LLM的特征信息,同时通过部分对齐蒸馏,解耦特征和logits之间的强关联,从而缓解特征对齐和logit置信度之间的冲突。这样可以提高草稿模型预测的准确性,生成更高质量的候选序列。
技术框架:FSPAD沿用现有的无损推测解码框架,主要包含目标LLM和草稿模型两个部分。首先,利用目标LLM的token embeddings对特征进行采样,然后将采样后的特征输入到草稿模型中进行预测。草稿模型的训练采用部分对齐蒸馏,目标是使草稿模型的预测结果与目标LLM的输出尽可能一致。最后,使用目标LLM并行验证草稿模型生成的候选序列。
关键创新:论文的关键创新在于提出了特征采样和部分对齐蒸馏两种方法。特征采样解决了草稿模型难以直接利用目标LLM特征的问题,部分对齐蒸馏缓解了特征对齐和logit置信度之间的冲突。这两种方法相互配合,共同提升了草稿模型的预测准确性。
关键设计:特征采样使用目标LLM的token embeddings作为采样权重,选择与当前token最相关的特征。部分对齐蒸馏通过调整损失函数,降低草稿模型特征与logits之间的依赖关系,具体实现方式未知。
🖼️ 关键图片
📊 实验亮点
实验结果表明,FSPAD在Vicuna和LLaMA3-Instruct系列模型上,在多轮对话、翻译、摘要、问答、数学推理和检索增强生成等任务中均优于现有最佳方法。具体性能提升数据未知,但总体表现超越了当前最优水平。
🎯 应用场景
该研究成果可广泛应用于需要加速LLM推理的场景,例如在线对话系统、机器翻译、文本摘要、问答系统等。通过提高推理效率,可以降低计算成本,提升用户体验,并促进LLM在资源受限设备上的部署。未来,该方法有望进一步扩展到其他模型架构和任务中。
📄 摘要(原文)
Lossless speculative decoding accelerates target large language model (LLM) inference by employing a lightweight draft model for generating tree-structured candidates, which are subsequently verified in parallel by the target LLM. Currently, effective approaches leverage feature-level rather than token-level autoregression within the draft model to facilitate more straightforward predictions and enhanced knowledge distillation. In this paper, we reassess these approaches and propose FSPAD (Feature Sampling and Partial Alignment Distillation for Lossless Speculative Decoding), which introduces two straightforward and effective components within the existing framework to boost lossless speculative decoding. Firstly, FSPAD utilizes token embeddings to sample features of the target LLM in high-dimensional space before feeding them into the draft model, due to the inherent uncertainty of the features preventing the draft model from obtaining the specific token output by the target LLM. Secondly, FSPAD introduces partial alignment distillation to weaken the draft model's connection between features and logits, aiming to reduce the conflict between feature alignment and logit confidence during training. Our experiments include both greedy and non-greedy decoding on the largest and smallest models from the Vicuna and LLaMA3-Instruct series, as well as tasks in multi-turn conversation, translation, summarization, question answering, mathematical reasoning, and retrieval-augmented generation. The results show that FSPAD outperforms the state-of-the-art method across all the aforementioned tasks and target LLMs.