Model Merging via Multi-Teacher Knowledge Distillation
作者: Seyed Arshan Dalili, Mehrdad Mahdavi
分类: cs.LG, cs.AI
发布日期: 2025-12-24
🔗 代码/项目: GITHUB
💡 一句话要点
提出SAMerging,通过多教师知识蒸馏实现模型合并,提升泛化性能。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 模型合并 知识蒸馏 多任务学习 泛化能力 PAC-Bayes Sharpness-Aware Minimization 深度学习
📋 核心要点
- 现有模型合并方法依赖启发式参数组合,缺乏理论指导,导致性能不稳定且对初始化敏感。
- 论文将模型合并视为多教师知识蒸馏,通过最小化KL散度来优化合并模型,并推导出泛化界限。
- 提出的SAMerging方法利用Sharpness-Aware Minimization寻找平坦最小值,在视觉和NLP任务上取得SOTA性能。
📝 摘要(中文)
模型合并已成为联合多任务学习的一种轻量级替代方案,但合并模型的泛化特性在很大程度上仍未被探索。建立此类理论保证并非易事,因为合并过程通常禁止访问原始训练数据,并且涉及组合在根本上异构的数据分布上微调的模型。由于缺乏对这些动态的原则性理解,当前的方法通常依赖于启发式方法来近似参数的最佳组合。这种依赖在系数缩放中最为关键,系数缩放调节每个微调模型对共享参数的贡献程度。然而,在没有原则性目标来指导选择的情况下,这些方法会导致脆弱的性能,并且对缩放初始化高度敏感。我们通过以下方式解决此差距:(i) 专门为模型合并设置建立一种新颖的、感知平坦度的PAC-Bayes泛化界限。该分析引入了一个“跨任务异质性”项,该项正式捕获了不同的微调模型先验与目标多任务分布之间的不匹配。(ii) 在此理论见解的指导下,我们将模型合并定义为在稀缺的未标记数据上的多教师知识蒸馏。我们正式证明,最小化学生-教师Kullback-Leibler散度可以直接收紧合并模型超额风险的上限。在导出的感知平坦度的界限的指导下,(iii) 我们通过SAMerging来实现此目标,SAMerging是一种采用Sharpness-Aware Minimization (SAM) 来寻找平坦最小值的方法。在经验上,SAMerging在视觉和NLP基准测试中建立了新的最先进水平,取得了显著的性能。代码可在 https://github.com/arshandalili/SAMerging 获得。
🔬 方法详解
问题定义:模型合并旨在将多个在不同任务上微调的模型合并为一个模型,以实现多任务学习的目的。现有方法主要依赖启发式算法来确定各个模型的参数权重,缺乏理论支撑,导致合并后的模型性能不稳定,对参数初始化敏感,且泛化能力难以保证。尤其是在数据异构性较高的情况下,如何有效融合不同模型的知识成为一个挑战。
核心思路:论文的核心思路是将模型合并问题转化为多教师知识蒸馏问题。通过将多个微调后的模型视为教师模型,利用少量未标注数据进行知识蒸馏,训练出一个学生模型,该学生模型即为合并后的模型。这种方法的核心在于利用知识蒸馏的优势,将多个教师模型的知识有效地传递给学生模型,从而提高合并模型的泛化能力。
技术框架:SAMerging的整体框架如下:首先,对多个在不同任务上预训练的模型进行微调。然后,利用少量未标注数据,将这些微调后的模型作为教师模型,使用知识蒸馏的方法训练一个学生模型。在训练过程中,采用Sharpness-Aware Minimization (SAM) 算法来寻找平坦最小值,从而提高模型的泛化能力。整个流程可以概括为:微调 -> 多教师知识蒸馏 -> SAM优化。
关键创新:论文的关键创新在于:(1) 将模型合并问题转化为多教师知识蒸馏问题,为模型合并提供了一种新的视角和方法。(2) 提出了一个针对模型合并的 flatness-aware PAC-Bayes 泛化界限,为模型合并的理论分析提供了基础。(3) 结合知识蒸馏和SAM算法,提出SAMerging方法,有效地提高了合并模型的泛化能力。与现有方法相比,SAMerging具有更强的理论支撑和更好的实验效果。
关键设计:SAMerging的关键设计包括:(1) 使用Kullback-Leibler (KL) 散度作为知识蒸馏的损失函数,以衡量学生模型和教师模型之间的差异。(2) 采用Sharpness-Aware Minimization (SAM) 算法来寻找平坦最小值,SAM通过在参数空间中寻找对扰动不敏感的解来提高模型的泛化能力。(3) 论文还设计了一个“跨任务异质性”项,用于衡量不同任务之间的差异,并在优化过程中加以考虑。
🖼️ 关键图片
📊 实验亮点
SAMerging在多个视觉和NLP基准测试中取得了显著的性能提升,超越了现有的模型合并方法,建立了新的SOTA。实验结果表明,SAMerging能够有效地融合不同模型的知识,提高合并模型的泛化能力,尤其是在数据异构性较高的情况下,优势更加明显。
🎯 应用场景
该研究成果可应用于各种需要多任务学习的场景,例如自动驾驶、智能医疗、自然语言处理等。通过模型合并,可以有效地利用已有的模型资源,降低训练成本,提高模型性能。该方法在资源受限的边缘设备上具有重要的应用价值,可以实现轻量级的多任务学习。
📄 摘要(原文)
Model merging has emerged as a lightweight alternative to joint multi-task learning (MTL), yet the generalization properties of merged models remain largely unexplored. Establishing such theoretical guarantees is non-trivial, as the merging process typically forbids access to the original training data and involves combining fine-tuned models trained on fundamentally heterogeneous data distributions. Without a principled understanding of these dynamics, current methods often rely on heuristics to approximate the optimal combination of parameters. This dependence is most critical in coefficient scaling, the weighting factors that modulate the magnitude of each fine-tuned model's contribution to the shared parameter. However, without a principled objective to guide their selection, these methods lead to brittle performance and are highly sensitive to scaling initialization. We address this gap by (i) establishing a novel flatness-aware PAC-Bayes generalization bound specifically for the model merging setting. This analysis introduces a "cross-task heterogeneity" term that formally captures the mismatch between diverse fine-tuned model priors and the target multi-task distributions. Guided by this theoretical insight, (ii) we frame model merging as multi-teacher knowledge distillation on scarce, unlabeled data. We formally demonstrate that minimizing the student-teacher Kullback-Leibler divergence directly tightens the upper bound on the merged model's excess risk. Guided by the flatness-aware bound derived, (iii) we operationalize this objective via SAMerging, a method that employs Sharpness-Aware Minimization (SAM) to find flat minima. Empirically, SAMerging establishes a new state of the art across vision and NLP benchmarks, achieving remarkable performance. The code is available at https://github.com/arshandalili/SAMerging.