Inference-Time Diffusion Model Distillation
作者: Geon Yeong Park, Sang Wan Lee, Jong Chul Ye
分类: cs.CV, cs.AI
发布日期: 2024-12-12
备注: Code: https://github.com/geonyeong-park/inference_distillation
🔗 代码/项目: GITHUB
💡 一句话要点
提出Distillation++,通过推理时教师引导优化,提升扩散模型蒸馏性能。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 扩散模型 蒸馏 推理加速 教师引导 分数蒸馏采样 逆向采样 近端优化
📋 核心要点
- 扩散模型蒸馏虽然加速了推理,但性能与原始模型存在差距,且受分布偏移和误差累积影响。
- Distillation++通过在推理时引入教师模型引导的优化,实时改进学生模型的去噪过程。
- 实验表明,Distillation++显著优于现有蒸馏方法,尤其是在采样初期阶段,提升效果明显。
📝 摘要(中文)
扩散模型蒸馏通过将逆向采样过程压缩到更少的步骤来有效加速推理。然而,与预训练的扩散模型相比,这些模型仍然存在性能差距,这种差距因分布偏移和多步采样过程中累积的误差而加剧。为了解决这个问题,我们引入了Distillation++,这是一个新颖的推理时蒸馏框架,通过在采样过程中加入教师引导的细化来缩小这一差距。受到条件采样最新进展的启发,我们的方法将学生模型的采样重塑为一个具有分数蒸馏采样损失(SDS)的近端优化问题。为此,我们将蒸馏优化集成到逆向采样过程中,这可以看作是教师指导,利用预训练的扩散模型驱动学生采样轨迹朝着干净的流形前进。因此,Distillation++实时改进去噪过程,而无需额外的源数据或微调。Distillation++在最先进的蒸馏基线上展示了显著的改进,尤其是在早期采样阶段,使其成为为扩散蒸馏模型量身定制的强大引导采样过程。
🔬 方法详解
问题定义:扩散模型蒸馏旨在加速推理过程,但现有方法在性能上与原始模型存在差距。主要痛点在于,蒸馏模型在减少采样步骤的同时,会引入分布偏移和累积误差,导致生成质量下降。尤其是在早期采样阶段,误差更为明显。
核心思路:Distillation++的核心思路是在推理过程中,利用预训练的教师扩散模型来引导学生模型的采样轨迹。通过将学生模型的采样过程视为一个优化问题,并使用分数蒸馏采样损失(SDS)作为优化目标,使得学生模型的采样过程尽可能地接近教师模型所定义的干净流形。
技术框架:Distillation++框架主要包含以下几个阶段:1) 初始化:使用蒸馏后的学生模型进行初始采样;2) 教师引导:利用预训练的教师模型提供梯度信息,指导学生模型的采样方向;3) 近端优化:将教师引导信息融入到学生模型的采样过程中,通过优化SDS损失来更新学生模型的参数;4) 迭代采样:重复步骤2和3,直到达到预定的采样步数。
关键创新:Distillation++的关键创新在于将蒸馏优化集成到逆向采样过程中,实现了推理时的实时优化。与传统的蒸馏方法不同,Distillation++不需要额外的源数据或微调,而是直接利用预训练的教师模型来指导学生模型的采样过程,从而有效地减少了分布偏移和累积误差。
关键设计:Distillation++的关键设计包括:1) 使用分数蒸馏采样损失(SDS)作为优化目标,确保学生模型的采样轨迹与教师模型一致;2) 将学生模型的采样过程重塑为一个近端优化问题,便于利用梯度信息进行优化;3) 在推理时进行优化,无需额外的训练数据或微调。
🖼️ 关键图片
📊 实验亮点
Distillation++在实验中表现出显著的性能提升,尤其是在早期采样阶段。与现有的蒸馏基线相比,Distillation++能够生成更高质量的图像,并且在相同的采样步数下,能够取得更好的FID (Fréchet Inception Distance) 分数。具体性能数据需要在论文中查找。
🎯 应用场景
Distillation++可应用于各种需要加速扩散模型推理的场景,例如图像生成、图像编辑、视频生成等。该方法能够显著提升生成速度,同时保持较高的生成质量,具有广泛的应用前景。此外,该方法还可以用于模型压缩和知识迁移等领域。
📄 摘要(原文)
Diffusion distillation models effectively accelerate reverse sampling by compressing the process into fewer steps. However, these models still exhibit a performance gap compared to their pre-trained diffusion model counterparts, exacerbated by distribution shifts and accumulated errors during multi-step sampling. To address this, we introduce Distillation++, a novel inference-time distillation framework that reduces this gap by incorporating teacher-guided refinement during sampling. Inspired by recent advances in conditional sampling, our approach recasts student model sampling as a proximal optimization problem with a score distillation sampling loss (SDS). To this end, we integrate distillation optimization during reverse sampling, which can be viewed as teacher guidance that drives student sampling trajectory towards the clean manifold using pre-trained diffusion models. Thus, Distillation++ improves the denoising process in real-time without additional source data or fine-tuning. Distillation++ demonstrates substantial improvements over state-of-the-art distillation baselines, particularly in early sampling stages, positioning itself as a robust guided sampling process crafted for diffusion distillation models. Code: https://github.com/geonyeong-park/inference_distillation.