Amortizing intractable inference in diffusion models for vision, language, and control

📄 arXiv: 2405.20971v2 📥 PDF

作者: Siddarth Venkatraman, Moksh Jain, Luca Scimeca, Minsu Kim, Marcin Sendera, Mohsin Hasan, Luke Rowe, Sarthak Mittal, Pablo Lemos, Emmanuel Bengio, Alexandre Adam, Jarrid Rector-Brooks, Yoshua Bengio, Glen Berseth, Nikolay Malkin

分类: cs.LG, cs.CV

发布日期: 2024-05-31 (更新: 2025-01-13)

备注: NeurIPS 2024; code: https://github.com/GFNOrg/diffusion-finetuning


💡 一句话要点

提出相对轨迹平衡以解决扩散模型后验推断问题

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 扩散模型 后验推断 无数据学习 深度强化学习 多模态数据 生成模型 离线强化学习

📋 核心要点

  1. 现有方法在扩散模型中处理后验推断时仅能近似解决,面临效率和准确性挑战。
  2. 论文提出通过相对轨迹平衡这一无数据学习目标,改进扩散模型的后验采样能力。
  3. 实验结果显示,该方法在视觉、语言和多模态数据处理上均表现出显著的性能提升。

📝 摘要(中文)

扩散模型在视觉、语言和强化学习中已成为有效的分布估计器,但作为下游任务的先验使用时面临难以处理的后验推断问题。本文研究了在扩散生成模型先验和黑箱约束函数的框架下,如何通过无数据学习目标相对轨迹平衡来进行后验采样。我们证明了该方法的渐近正确性,并展示了其在视觉(分类器引导)、语言(离散扩散LLM下的填充)和多模态数据(文本到图像生成)中的广泛潜力。此外,我们还将相对轨迹平衡应用于基于分数的行为先验的连续控制问题,在离线强化学习基准上取得了最先进的结果。

🔬 方法详解

问题定义:本文旨在解决扩散模型在下游任务中后验推断的难题。现有方法在处理复杂后验时效率低下,且仅能在有限情况下近似解决。

核心思路:论文提出的相对轨迹平衡方法通过无数据学习目标来优化扩散模型的后验采样,利用生成流网络的视角,结合深度强化学习技术以提升模式覆盖率。

技术框架:整体架构包括扩散生成模型先验和黑箱约束函数,首先通过相对轨迹平衡进行训练,然后在不同任务中进行后验采样。主要模块包括生成模型、约束函数和强化学习策略。

关键创新:相对轨迹平衡是本文的核心创新,与现有方法相比,它提供了一种更为准确和高效的后验推断方式,能够在更广泛的场景中应用。

关键设计:在训练过程中,采用特定的损失函数来优化模型参数,确保生成模型能够有效地从后验中采样,同时设计了适应性强的网络结构以支持多模态数据处理。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,使用相对轨迹平衡方法在视觉任务中实现了比传统方法更高的分类准确率,在语言任务中填充效果显著提升,且在离线强化学习基准上达到了最先进的性能,展示了该方法的广泛适用性和优越性。

🎯 应用场景

该研究的潜在应用领域包括计算机视觉中的图像生成、自然语言处理中的文本填充以及多模态数据的生成任务。通过提高后验推断的效率和准确性,未来可在智能助手、自动内容生成和机器人控制等多个领域产生重要影响。

📄 摘要(原文)

Diffusion models have emerged as effective distribution estimators in vision, language, and reinforcement learning, but their use as priors in downstream tasks poses an intractable posterior inference problem. This paper studies amortized sampling of the posterior over data, $\mathbf{x}\sim p^{\rm post}(\mathbf{x})\propto p(\mathbf{x})r(\mathbf{x})$, in a model that consists of a diffusion generative model prior $p(\mathbf{x})$ and a black-box constraint or likelihood function $r(\mathbf{x})$. We state and prove the asymptotic correctness of a data-free learning objective, relative trajectory balance, for training a diffusion model that samples from this posterior, a problem that existing methods solve only approximately or in restricted cases. Relative trajectory balance arises from the generative flow network perspective on diffusion models, which allows the use of deep reinforcement learning techniques to improve mode coverage. Experiments illustrate the broad potential of unbiased inference of arbitrary posteriors under diffusion priors: in vision (classifier guidance), language (infilling under a discrete diffusion LLM), and multimodal data (text-to-image generation). Beyond generative modeling, we apply relative trajectory balance to the problem of continuous control with a score-based behavior prior, achieving state-of-the-art results on benchmarks in offline reinforcement learning.