Sparse Logit Sampling: Accelerating Knowledge Distillation in LLMs
作者: Anshumann, Mohd Abbas Zaidi, Akhil Kedia, Jinwoo Ahn, Taehwak Kwon, Kangwook Lee, Haejun Lee, Joohyung Lee
分类: cs.LG, cs.AI, cs.CL
发布日期: 2025-03-21 (更新: 2025-07-24)
备注: Accepted as Oral paper at ACL 2025. Source code is available at https://github.com/akhilkedia/RandomSamplingKD . Anshumann, Mohd Abbas Zaidi and Akhil Kedia have Equal Contribution
💡 一句话要点
提出基于重要性采样的稀疏Logit蒸馏方法,加速LLM知识蒸馏。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 知识蒸馏 大语言模型 重要性采样 模型压缩 稀疏Logit 无偏估计 预训练
📋 核心要点
- 现有稀疏知识蒸馏方法(如Top-K概率缓存)会引入偏差,导致学生模型性能下降和校准不佳。
- 提出基于重要性采样的随机采样知识蒸馏,提供无偏估计,保留梯度,并显著减少存储的logits数量。
- 实验表明,该方法在加速学生模型训练的同时,保持了与完整蒸馏相当的性能,且开销很小。
📝 摘要(中文)
知识蒸馏是一种经济高效的技术,可用于在大语言模型中提炼知识,前提是教师模型的输出logits可以预先计算和缓存。然而,成功地将其应用于预训练在很大程度上仍未被探索。本文证明,诸如缓存Top-K概率等稀疏知识蒸馏的朴素方法,虽然直观,但会给学生模型提供教师概率分布的有偏估计,导致次优的性能和校准。我们提出了一种基于重要性采样的方法“随机采样知识蒸馏”,该方法提供无偏估计,在期望中保留梯度,并且需要存储显著稀疏的logits。我们的方法能够以边际开销(<10%)相比于基于交叉熵的训练更快地训练学生模型,同时在300M到3B的一系列模型尺寸上保持与完整蒸馏相比具有竞争力的性能。
🔬 方法详解
问题定义:论文旨在解决大语言模型(LLM)知识蒸馏中,教师模型输出logits存储开销大和稀疏蒸馏方法引入偏差的问题。传统的知识蒸馏需要存储完整的教师模型logits,这对于大型LLM来说是不现实的。而简单的稀疏化方法,如Top-K采样,会引入偏差,导致学生模型性能下降。
核心思路:论文的核心思路是使用重要性采样来构建教师模型概率分布的无偏估计。通过对教师模型的logits进行随机采样,并根据采样概率进行加权,可以得到一个无偏的概率分布估计,从而避免了传统稀疏化方法引入的偏差。
技术框架:该方法主要包含以下几个步骤:1) 预先计算并缓存教师模型的logits;2) 在训练过程中,对教师模型的logits进行随机采样;3) 根据采样概率计算重要性权重;4) 使用加权后的logits计算蒸馏损失,并更新学生模型的参数。
关键创新:该方法最重要的创新点在于使用重要性采样来构建教师模型概率分布的无偏估计。与传统的稀疏化方法相比,该方法可以避免引入偏差,从而提高学生模型的性能。此外,该方法只需要存储稀疏的logits,从而降低了存储开销。
关键设计:论文使用均匀分布作为采样分布,并根据采样概率计算重要性权重。蒸馏损失函数可以使用KL散度或交叉熵损失。实验中,作者使用了不同大小的模型(300M到3B),并比较了该方法与完整蒸馏和Top-K采样等基线方法的性能。
🖼️ 关键图片
📊 实验亮点
实验结果表明,提出的随机采样知识蒸馏方法在300M到3B的模型尺寸范围内,能够以小于10%的额外开销,实现与完整蒸馏相当的性能,同时显著降低了存储需求。与Top-K采样等基线方法相比,该方法能够提供更准确的教师模型概率分布估计,从而提高学生模型的性能。
🎯 应用场景
该研究成果可应用于各种需要知识蒸馏的大语言模型场景,例如模型压缩、加速推理、迁移学习等。通过降低存储开销和提高蒸馏效率,该方法可以帮助研究人员和工程师更有效地训练和部署LLM,尤其是在资源受限的环境下。该方法也有助于预训练LLM,降低预训练的计算成本。
📄 摘要(原文)
Knowledge distillation can be a cost-effective technique to distill knowledge in Large Language Models, if the teacher output logits can be pre-computed and cached. However, successfully applying this to pre-training remains largely unexplored. In this work, we prove that naive approaches for sparse knowledge distillation such as caching Top-K probabilities, while intuitive, provide biased estimates of teacher probability distribution to the student, resulting in suboptimal performance and calibration. We propose an importance-sampling-based method `Random Sampling Knowledge Distillation', which provides unbiased estimates, preserves the gradient in expectation, and requires storing significantly sparser logits. Our method enables faster training of student models with marginal overhead (<10%) compared to cross-entropy based training, while maintaining competitive performance compared to full distillation, across a range of model sizes from 300M to 3B.