Maximum Entropy Inverse Reinforcement Learning of Diffusion Models with Energy-Based Models

📄 arXiv: 2407.00626v2 📥 PDF

作者: Sangwoong Yoon, Himchan Hwang, Dohyun Kwon, Yung-Kyun Noh, Frank C. Park

分类: cs.LG, cs.AI

发布日期: 2024-06-30 (更新: 2024-10-31)

备注: NeurIPS 2024 Oral Presentation. Code is released at https://github.com/swyoon/Diffusion-by-MaxEntIRL


💡 一句话要点

提出基于最大熵逆强化学习的扩散模型训练方法,提升生成质量并加速采样。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 扩散模型 逆强化学习 能量模型 生成模型 最大熵 异常检测 动态规划

📋 核心要点

  1. 扩散模型生成高质量样本需要大量时间步,降低了生成效率,现有方法难以在少量步数下保持生成质量。
  2. 论文提出DxMI,通过最大熵逆强化学习框架,联合训练扩散模型和能量模型,利用能量模型提供的密度估计指导扩散模型训练。
  3. 实验表明,DxMI能显著提升扩散模型在少量步数下的生成质量,并稳定EBM训练,提升异常检测性能。

📝 摘要(中文)

本文提出了一种基于最大熵逆强化学习(IRL)的方法,用于提升扩散生成模型的样本质量,尤其是在生成时间步数较少的情况下。类似于IRL基于从专家演示中学到的奖励函数来训练策略,我们使用从训练数据估计的对数概率密度来训练(或微调)扩散模型。由于我们采用基于能量的模型(EBM)来表示对数密度,因此我们的方法归结为扩散模型和EBM的联合训练。我们提出的IRL公式,命名为Diffusion by Maximum Entropy IRL (DxMI),是一个极小极大问题,当两个模型都收敛到数据分布时达到平衡。熵最大化在DxMI中起着关键作用,促进了扩散模型的探索并确保了EBM的收敛。我们还提出了Diffusion by Dynamic Programming (DxDP),一种用于扩散模型的新型强化学习算法,作为DxMI中的一个子程序。DxDP通过将原始问题转换为最优控制公式,其中价值函数代替了时间上的反向传播,从而使DxMI中的扩散模型更新更加高效。我们的实验研究表明,使用DxMI微调的扩散模型可以在少至4步和10步内生成高质量的样本。此外,DxMI无需MCMC即可训练EBM,从而稳定EBM训练动态并提高异常检测性能。

🔬 方法详解

问题定义:论文旨在解决扩散模型生成样本时,在时间步数较少的情况下,生成质量下降的问题。现有的扩散模型通常需要大量的采样步骤才能生成高质量的样本,这限制了其在实际应用中的效率。因此,如何在保证生成质量的前提下,减少采样步骤是本研究要解决的核心问题。

核心思路:论文的核心思路是将扩散模型的训练过程视为一个逆强化学习(IRL)问题。通过学习一个能量模型(EBM)来估计数据的对数概率密度,并将该密度作为奖励函数来指导扩散模型的训练。最大熵原则被引入,以鼓励扩散模型探索更广泛的样本空间,并确保能量模型的收敛。

技术框架:整体框架包含两个主要部分:扩散模型和能量模型。首先,使用训练数据训练一个能量模型,使其能够准确地估计数据的对数概率密度。然后,将该密度作为奖励函数,使用逆强化学习算法(DxMI)来微调扩散模型。DxMI算法包含一个子程序DxDP,它是一种基于动态规划的强化学习算法,用于高效地更新扩散模型。整个过程是一个极小极大博弈,扩散模型试图生成更符合数据分布的样本,而能量模型试图更准确地估计数据分布。

关键创新:论文的关键创新在于将逆强化学习的思想引入到扩散模型的训练中,并提出了DxMI算法。与传统的扩散模型训练方法不同,DxMI不是直接优化生成样本与真实样本之间的距离,而是通过学习一个能量模型来间接地指导扩散模型的训练。此外,DxDP算法通过动态规划的方式,避免了在时间上的反向传播,从而提高了训练效率。

关键设计:能量模型采用标准的EBM结构,损失函数包括能量项和对比散度项。扩散模型采用标准的扩散模型结构,但其更新方式由DxMI算法决定。DxMI算法的关键在于如何将能量模型的输出转化为扩散模型的更新信号。具体来说,DxDP算法通过计算价值函数来估计每个时间步的奖励,并使用该价值函数来更新扩散模型的参数。最大熵正则化被添加到目标函数中,以鼓励扩散模型探索更广泛的样本空间。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,使用DxMI微调的扩散模型可以在少至4步和10步内生成高质量的样本,显著减少了生成时间。此外,DxMI在训练EBM时无需MCMC,稳定了EBM的训练动态,并在异常检测任务中取得了优于现有方法的性能。

🎯 应用场景

该研究成果可应用于图像生成、音频合成、视频生成等领域,尤其是在对生成速度有较高要求的场景下。例如,可以用于实时图像编辑、快速生成游戏素材、以及加速科学计算中的数据生成过程。此外,该方法还可以用于异常检测,通过训练一个能够准确估计数据分布的能量模型,可以有效地识别出与正常数据分布不符的异常样本。

📄 摘要(原文)

We present a maximum entropy inverse reinforcement learning (IRL) approach for improving the sample quality of diffusion generative models, especially when the number of generation time steps is small. Similar to how IRL trains a policy based on the reward function learned from expert demonstrations, we train (or fine-tune) a diffusion model using the log probability density estimated from training data. Since we employ an energy-based model (EBM) to represent the log density, our approach boils down to the joint training of a diffusion model and an EBM. Our IRL formulation, named Diffusion by Maximum Entropy IRL (DxMI), is a minimax problem that reaches equilibrium when both models converge to the data distribution. The entropy maximization plays a key role in DxMI, facilitating the exploration of the diffusion model and ensuring the convergence of the EBM. We also propose Diffusion by Dynamic Programming (DxDP), a novel reinforcement learning algorithm for diffusion models, as a subroutine in DxMI. DxDP makes the diffusion model update in DxMI efficient by transforming the original problem into an optimal control formulation where value functions replace back-propagation in time. Our empirical studies show that diffusion models fine-tuned using DxMI can generate high-quality samples in as few as 4 and 10 steps. Additionally, DxMI enables the training of an EBM without MCMC, stabilizing EBM training dynamics and enhancing anomaly detection performance.