SD$^2$: Self-Distilled Sparse Drafters
作者: Mike Lasby, Nish Sinnadurai, Valavan Manohararajah, Sean Lie, Yani Ioannou, Vithursan Thangarasa
分类: cs.CL, cs.AI
发布日期: 2025-04-10 (更新: 2025-05-31)
备注: 24 pages
💡 一句话要点
提出SD$^2$,通过自蒸馏稀疏化草稿模型提升LLM推断效率,尤其在通用辅助生成场景下。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 推测解码 模型压缩 自蒸馏 稀疏化 语言模型 模型加速 通用辅助生成
📋 核心要点
- 现有推断加速方法如推测解码依赖草稿模型,但草稿模型的效率和对齐是挑战。
- SD$^2$利用自蒸馏和细粒度稀疏化,生成高效且与目标模型对齐的草稿模型。
- 实验表明,SD$^2$显著提升了token接受率并降低了计算量,优于传统剪枝方法。
📝 摘要(中文)
本文提出了一种名为自蒸馏稀疏草稿模型(SD$^2$)的新方法,旨在利用自数据蒸馏和细粒度权重稀疏性来生成高效且对齐良好的草稿模型,从而加速大型语言模型(LLM)的推断过程。SD$^2$系统性地提高了草稿token的接受率,并显著减少了乘加运算(MACs),即使在通用辅助生成(UAG)设置下,即草稿模型和目标模型来自不同的模型家族,也能有效工作。在Llama-3.1-70B目标模型上,SD$^2$相比于层剪枝草稿模型,平均接受长度(MAL)提高了1.59倍,与密集草稿模型相比,MACs降低了43.87%,MAL降低了8.36%。实验表明,我们的1.5B和3B非结构化稀疏草稿模型在端到端延迟改进方面优于密集和层剪枝模型,突出了稀疏感知微调和压缩策略在提高LLM推断效率同时保持与目标模型对齐的潜力。
🔬 方法详解
问题定义:现有推测解码方法依赖于草稿模型生成候选token序列,然后由目标模型验证。草稿模型的质量直接影响加速效果。然而,训练高效且与目标模型对齐的草稿模型是一个挑战。传统的层剪枝方法虽然可以减少计算量,但往往会牺牲模型性能,导致token接受率下降。此外,通用辅助生成(UAG)场景下,草稿模型和目标模型来自不同架构,对齐难度更大。
核心思路:SD$^2$的核心思路是通过自蒸馏和细粒度权重稀疏化,在不显著降低模型性能的前提下,大幅减少草稿模型的计算量。自蒸馏保证了草稿模型与目标模型的对齐,而稀疏化则降低了计算复杂度。通过sparsity-aware fine-tuning,模型能够在稀疏结构下保持甚至提升性能。
技术框架:SD$^2$包含两个主要阶段:自蒸馏和稀疏化微调。首先,使用目标模型作为教师模型,对草稿模型进行自蒸馏,使其学习目标模型的输出分布。然后,对草稿模型进行权重稀疏化,去除不重要的连接。最后,进行稀疏感知微调,在稀疏结构下进一步优化模型参数,使其更好地适应目标模型。
关键创新:SD$^2$的关键创新在于结合了自蒸馏和细粒度权重稀疏化,并提出了sparsity-aware fine-tuning策略。与传统的层剪枝方法相比,SD$^2$能够更精细地控制模型的计算量,并在保持模型性能的同时,实现更高的压缩率。此外,SD$^2$在UAG场景下也表现出色,表明其具有良好的泛化能力。
关键设计:SD$^2$使用了非结构化稀疏化,即随机地去除权重,而不是去除整个层或通道。稀疏率是一个重要的参数,需要根据目标模型的性能和计算资源进行调整。自蒸馏过程中,使用了KL散度损失函数来衡量草稿模型和目标模型输出分布的差异。Sparsity-aware fine-tuning使用了特殊的优化器,能够处理稀疏矩阵的梯度更新。
🖼️ 关键图片
📊 实验亮点
在Llama-3.1-70B目标模型上,SD$^2$相比于层剪枝草稿模型,平均接受长度(MAL)提高了1.59倍。与密集草稿模型相比,MACs降低了43.87%,而MAL仅降低了8.36%。1.5B和3B的SD$^2$模型在端到端延迟改进方面优于密集和层剪枝模型,证明了该方法的有效性。
🎯 应用场景
SD$^2$可应用于各种需要加速LLM推断的场景,例如在线对话系统、机器翻译、代码生成等。通过降低计算成本和延迟,SD$^2$可以使LLM在资源受限的设备上运行,并提高用户体验。未来,SD$^2$可以与其他模型压缩技术相结合,进一步提升LLM的效率。
📄 摘要(原文)
Speculative decoding is a powerful technique for reducing the latency of Large Language Models (LLMs), offering a fault-tolerant framework that enables the use of highly compressed draft models. In this work, we introduce Self-Distilled Sparse Drafters (SD$^2$), a novel methodology that leverages self-data distillation and fine-grained weight sparsity to produce highly efficient and well-aligned draft models. SD$^2$ systematically enhances draft token acceptance rates while significantly reducing Multiply-Accumulate operations (MACs), even in the Universal Assisted Generation (UAG) setting, where draft and target models originate from different model families. On a Llama-3.1-70B target model, SD$^2$ provides a 1.59$\times$ higher Mean Accepted Length (MAL) compared to layer-pruned draft models and reduces MACs by over 43.87% with a 8.36% reduction in MAL compared to a dense draft models. Our 1.5B and 3B unstructured sparse drafters outperform both dense and layer-pruned models in terms of end-to-end latency improvements; highlighting the potential of sparsity-aware fine-tuning and compression strategies to improve LLM inference efficiency while maintaining alignment with target models.