Distribution Backtracking Builds A Faster Convergence Trajectory for Diffusion Distillation
作者: Shengyuan Zhang, Ling Yang, Zejian Li, An Zhao, Chenye Meng, Changyuan Yang, Guang Yang, Zhiyuan Yang, Lingyun Sun
分类: cs.CV
发布日期: 2024-08-28 (更新: 2025-04-17)
备注: Our code is publicly available on https://github.com/SYZhang0805/DisBack
🔗 代码/项目: GITHUB
💡 一句话要点
提出DisBack,通过分布回溯加速扩散模型蒸馏的收敛速度
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 扩散模型 蒸馏 生成模型 收敛加速 分布回溯
📋 核心要点
- 现有扩散模型蒸馏方法忽略了学生模型与教师模型之间的收敛轨迹,导致早期训练阶段分数不匹配。
- DisBack通过记录教师模型的退化路径,反向模拟收敛轨迹,引导学生模型学习中间分布。
- 实验表明,DisBack加速了收敛速度,并在ImageNet 64x64上取得了1.38的FID分数,性能优于现有方法。
📝 摘要(中文)
加速扩散模型的采样速度仍然是一个重要的挑战。最近的分数蒸馏方法将一个重的教师模型蒸馏成一个学生生成器,以实现单步生成,该生成器通过计算两个分数函数在学生模型生成的样本上的差异进行优化。然而,在蒸馏过程的早期阶段存在分数不匹配问题,因为现有方法主要关注使用预训练扩散模型的终点作为教师模型,忽略了学生生成器和教师模型之间的收敛轨迹的重要性。为了解决这个问题,我们通过引入教师模型的整个收敛轨迹来扩展分数蒸馏过程,并提出了分布回溯蒸馏(DisBack)。DisBack由两个阶段组成:退化记录和分布回溯。退化记录旨在获得教师模型的收敛轨迹,它记录了从训练好的教师模型到未训练的初始学生生成器的退化路径。退化路径隐式地表示了教师模型的中间分布,其反向可以被视为从学生生成器到教师模型的收敛轨迹。然后,分布回溯训练一个学生生成器来回溯沿路径的中间分布,以近似教师模型的收敛轨迹。大量的实验表明,DisBack比现有的蒸馏方法实现了更快和更好的收敛,并完成了相当的生成性能,在ImageNet 64x64数据集上的FID分数为1.38。值得注意的是,DisBack易于实现,并且可以推广到现有的蒸馏方法以提高性能。我们的代码已在https://github.com/SYZhang0805/DisBack上公开。
🔬 方法详解
问题定义:现有的扩散模型蒸馏方法,如score distillation,主要依赖于预训练扩散模型的最终状态作为教师模型,而忽略了从初始状态到最终状态的整个收敛过程。这导致在蒸馏的早期阶段,学生模型与教师模型之间存在较大的差异,即分数不匹配问题,从而影响了蒸馏的效率和最终性能。
核心思路:DisBack的核心思路是通过模拟教师模型的收敛轨迹,让学生模型逐步逼近教师模型的中间状态,从而缓解分数不匹配问题。具体来说,DisBack首先记录教师模型的退化路径,然后训练学生模型沿着这条路径回溯,学习教师模型的中间分布。
技术框架:DisBack包含两个主要阶段:退化记录(Degradation Recording)和分布回溯(Distribution Backtracking)。在退化记录阶段,通过逐步向训练好的教师模型添加噪声或进行其他形式的扰动,记录模型性能逐渐下降的路径。这条路径代表了教师模型从训练完成到未训练状态的演变过程。在分布回溯阶段,训练学生模型来逆向追踪这条退化路径,即从初始状态逐步学习教师模型的中间分布,最终达到与教师模型相似的性能。
关键创新:DisBack的关键创新在于引入了教师模型的收敛轨迹的概念,并设计了一种有效的方法来模拟和利用这条轨迹。与传统的蒸馏方法只关注教师模型的最终状态不同,DisBack关注整个学习过程,从而能够更有效地引导学生模型的训练。
关键设计:退化记录阶段,可以通过多种方式实现,例如逐步增加噪声的强度,或者逐步减少模型的参数量。分布回溯阶段,可以使用对抗训练、知识蒸馏等技术来训练学生模型。损失函数的设计需要考虑如何衡量学生模型与教师模型中间分布的相似度,例如可以使用KL散度或Wasserstein距离。
🖼️ 关键图片
📊 实验亮点
DisBack在ImageNet 64x64数据集上取得了显著的成果,FID分数为1.38,表明其生成图像的质量很高。实验结果还表明,DisBack比现有的蒸馏方法收敛速度更快,这意味着可以在更短的时间内训练出性能更好的学生模型。此外,DisBack易于实现,可以推广到现有的蒸馏方法以提高性能。
🎯 应用场景
DisBack可应用于各种需要加速扩散模型采样速度的场景,例如图像生成、视频生成、音频生成等。通过蒸馏,可以将计算量大的教师模型压缩成计算量小的学生模型,从而在资源受限的设备上实现快速生成。此外,DisBack还可以与其他蒸馏技术结合使用,进一步提高蒸馏的效率和性能。
📄 摘要(原文)
Accelerating the sampling speed of diffusion models remains a significant challenge. Recent score distillation methods distill a heavy teacher model into a student generator to achieve one-step generation, which is optimized by calculating the difference between the two score functions on the samples generated by the student model. However, there is a score mismatch issue in the early stage of the distillation process, because existing methods mainly focus on using the endpoint of pre-trained diffusion models as teacher models, overlooking the importance of the convergence trajectory between the student generator and the teacher model. To address this issue, we extend the score distillation process by introducing the entire convergence trajectory of teacher models and propose Distribution Backtracking Distillation (DisBack). DisBask is composed of two stages: Degradation Recording and Distribution Backtracking. Degradation Recording is designed to obtain the convergence trajectory of the teacher model, which records the degradation path from the trained teacher model to the untrained initial student generator. The degradation path implicitly represents the teacher model's intermediate distributions, and its reverse can be viewed as the convergence trajectory from the student generator to the teacher model. Then Distribution Backtracking trains a student generator to backtrack the intermediate distributions along the path to approximate the convergence trajectory of teacher models. Extensive experiments show that DisBack achieves faster and better convergence than the existing distillation method and accomplishes comparable generation performance, with FID score of 1.38 on ImageNet 64x64 dataset. Notably, DisBack is easy to implement and can be generalized to existing distillation methods to boost performance. Our code is publicly available on https://github.com/SYZhang0805/DisBack.