Self-Data Distillation for Recovering Quality in Pruned Large Language Models
作者: Vithursan Thangarasa, Ganesh Venkatesh, Mike Lasby, Nish Sinnadurai, Sean Lie
分类: cs.LG, cs.CL
发布日期: 2024-10-13 (更新: 2025-05-10)
备注: Accepted to MLSys 2025. Main paper: 14 pp., 4 figs., 6 tabs.; Supplementary: 5 pp
💡 一句话要点
提出自数据蒸馏微调方法,恢复剪枝大语言模型中的质量损失。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大语言模型 模型剪枝 自数据蒸馏 知识蒸馏 灾难性遗忘 模型压缩 推测解码
📋 核心要点
- 大规模语言模型剪枝后,模型质量显著下降,尤其在多步推理任务中,监督微调虽能恢复性能,但易导致灾难性遗忘。
- 论文提出自数据蒸馏微调方法,利用原始未剪枝模型生成蒸馏数据集,保留语义信息并减轻灾难性遗忘,从而提升剪枝模型的性能。
- 实验表明,自数据蒸馏微调在HuggingFace OpenLLM Leaderboard v1上优于标准SFT,平均准确率提升高达8%,并能有效减少FLOPs。
📝 摘要(中文)
大型语言模型在自然语言处理领域取得了显著进展,但其部署需要大量的计算和内存资源。随着模型规模的扩大,压缩技术对于平衡模型质量和计算效率至关重要。结构化剪枝是一种很有前途的降低模型复杂度的策略,它移除模型中不太重要的组件。然而,一次性剪枝通常会导致显著的质量下降,尤其是在需要多步推理的任务中。为了恢复损失的质量,通常采用监督微调(SFT),但它可能导致灾难性遗忘,从而改变模型学习到的数据分布。因此,解决剪枝和SFT带来的退化对于保持原始模型的质量至关重要。在这项工作中,我们利用自数据蒸馏微调来应对这些挑战。我们的方法利用原始的、未剪枝的模型来生成一个蒸馏数据集,该数据集保留了语义丰富性,并通过保持与基础模型知识的对齐来减轻灾难性遗忘。经验表明,自数据蒸馏始终优于标准SFT,在HuggingFace OpenLLM Leaderboard v1上的平均准确率提高了高达8%。具体来说,当在Llama3.1-8B Instruct上剪枝六个解码器块(即32层到26层,将模型大小从8.03B减少到6.72B参数)时,我们的方法保留了原始模型91.2%的准确率,而SFT为81.7%,同时减少了16.3%的实际FLOPs。此外,通过模型合并结合自数据蒸馏模型可以提高质量保持。此外,在推测解码中利用这些剪枝模型可以提高token接受率,从而提高应用环境中的推理效率。
🔬 方法详解
问题定义:论文旨在解决大语言模型剪枝后性能显著下降的问题,尤其是在多步推理任务中。现有的监督微调(SFT)方法虽然可以恢复部分性能,但容易导致灾难性遗忘,即模型忘记了原始数据分布中的知识。因此,如何在剪枝后既能恢复性能,又能避免灾难性遗忘,是本研究要解决的核心问题。
核心思路:论文的核心思路是利用原始的、未剪枝的模型作为“教师”,生成一个蒸馏数据集,然后用这个数据集来微调剪枝后的模型。这个蒸馏数据集包含了原始模型的知识,可以帮助剪枝后的模型在恢复性能的同时,避免灾难性遗忘。通过让剪枝后的模型学习原始模型的“思考方式”,从而更好地保留原始模型的知识。
技术框架:整体框架包含以下几个主要步骤:1) 对原始大语言模型进行结构化剪枝,得到剪枝后的模型。2) 使用原始的、未剪枝的模型,对大量无标签数据进行推理,生成蒸馏数据集。这个数据集包含了原始模型对这些数据的预测结果。3) 使用蒸馏数据集对剪枝后的模型进行微调。
关键创新:论文的关键创新在于使用自数据蒸馏来解决剪枝后模型的性能恢复问题。与传统的监督微调相比,自数据蒸馏可以更好地保留原始模型的知识,避免灾难性遗忘。此外,论文还探索了将自数据蒸馏模型进行模型合并,以及在推测解码中应用剪枝模型,进一步提升了模型的性能和效率。
关键设计:在生成蒸馏数据集时,需要选择合适的无标签数据。论文中可能使用了特定的数据集或数据生成策略,以确保蒸馏数据集的质量。在微调剪枝模型时,需要选择合适的学习率、batch size等超参数。此外,损失函数的设计也很重要,可能需要结合交叉熵损失和KL散度损失,以同时优化模型的预测准确率和与原始模型的相似度。
🖼️ 关键图片
📊 实验亮点
实验结果表明,自数据蒸馏微调在HuggingFace OpenLLM Leaderboard v1上优于标准SFT,平均准确率提升高达8%。在Llama3.1-8B Instruct上剪枝六个解码器块后,自数据蒸馏方法保留了原始模型91.2%的准确率,而SFT为81.7%,同时减少了16.3%的实际FLOPs。此外,模型合并和推测解码等技术进一步提升了模型的性能和效率。
🎯 应用场景
该研究成果可应用于各种需要部署大语言模型的场景,尤其是在资源受限的环境中,如移动设备、边缘计算等。通过剪枝和自数据蒸馏,可以在显著降低模型大小和计算复杂度的同时,保持较高的模型性能。此外,该方法还可以用于知识迁移和模型压缩等领域,具有广泛的应用前景。
📄 摘要(原文)
Large language models have driven significant progress in natural language processing, but their deployment requires substantial compute and memory resources. As models scale, compression techniques become essential for balancing model quality with computational efficiency. Structured pruning, which removes less critical components of the model, is a promising strategy for reducing complexity. However, one-shot pruning often results in significant quality degradation, particularly in tasks requiring multi-step reasoning. To recover lost quality, supervised fine-tuning (SFT) is commonly applied, but it can lead to catastrophic forgetting by shifting the model's learned data distribution. Therefore, addressing the degradation from both pruning and SFT is essential to preserve the original model's quality. In this work, we utilize self-data distilled fine-tuning to address these challenges. Our approach leverages the original, unpruned model to generate a distilled dataset that preserves semantic richness and mitigates catastrophic forgetting by maintaining alignment with the base model's knowledge. Empirically, we demonstrate that self-data distillation consistently outperforms standard SFT, improving average accuracy by up to 8% on the HuggingFace OpenLLM Leaderboard v1. Specifically, when pruning six decoder blocks on Llama3.1-8B Instruct (i.e., 32 to 26 layers, reducing the model size from 8.03B to 6.72B parameters), our method retains 91.2% of the original model's accuracy compared to 81.7% with SFT, while reducing real-world FLOPs by 16.3%. Furthermore, combining self-data distilled models through model merging yields enhanced quality retention. Additionally, leveraging these pruned models in speculative decoding increases token acceptance rates, thereby improving inference efficiency in applied settings.