JADAI: Jointly Amortizing Adaptive Design and Bayesian Inference

📄 arXiv: 2512.22999v1 📥 PDF

作者: Niels Bracher, Lars Kühmichel, Desi R. Ivanova, Xavier Intes, Paul-Christian Bürkner, Stefan T. Radev

分类: stat.ML, cs.AI, cs.LG

发布日期: 2025-12-28


💡 一句话要点

JADAI:联合学习自适应设计与贝叶斯推断,提升参数估计的信息增益。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 自适应设计 贝叶斯推断 主动学习 深度学习 参数估计

📋 核心要点

  1. 现有参数估计方法在设计变量优化以最大化信息增益方面存在不足,尤其是在高维和多模态后验分布场景下。
  2. JADAI框架通过联合学习策略网络、历史网络和推断网络,端到端地分摊贝叶斯自适应设计和推断过程。
  3. 实验结果表明,JADAI在标准自适应设计基准测试中表现出色,能够有效近似高维和多模态后验分布。

📝 摘要(中文)

本文研究了参数估计问题,其中设计变量可以被主动优化以最大化信息增益。为此,我们提出了JADAI,一个通过端到端训练策略网络、历史网络和推断网络来联合分摊贝叶斯自适应设计和推断的框架。这些网络最小化一个通用损失函数,该函数聚合了实验序列中后验误差的增量减少。推断网络通过基于扩散的后验估计器实例化,该估计器可以近似每个实验步骤中的高维和多模态后验。在标准自适应设计基准测试中,JADAI实现了优越或具有竞争力的性能。

🔬 方法详解

问题定义:论文旨在解决参数估计问题,特别是在主动学习或实验设计场景下,如何优化设计变量以最大化关于未知参数的信息增益。现有方法可能难以处理高维参数空间、多模态后验分布,或者无法有效地将自适应设计和贝叶斯推断过程整合起来。这导致次优的实验设计和参数估计结果。

核心思路:JADAI的核心思路是联合学习自适应设计策略和贝叶斯推断过程。通过训练一个策略网络来选择最佳的实验设计,同时训练一个推断网络来估计参数的后验分布。历史网络用于整合之前的实验信息,从而实现序列化的自适应设计。这种联合学习的方式允许模型在实验过程中不断学习和改进,从而更有效地获取信息。

技术框架:JADAI框架包含三个主要模块:策略网络、历史网络和推断网络。策略网络接收当前后验分布的表示,并输出下一个实验设计。历史网络用于编码之前的实验设计和观测结果,并将其传递给策略网络和推断网络。推断网络接收历史信息和当前实验的观测结果,并估计参数的后验分布。整个框架通过端到端的方式进行训练,以最小化一个通用损失函数,该函数衡量了实验序列中后验误差的减少。

关键创新:JADAI的关键创新在于将自适应设计和贝叶斯推断过程联合分摊。传统的自适应设计方法通常需要手动设计实验策略或使用启发式算法,而JADAI通过学习的方式自动优化实验策略。此外,JADAI使用基于扩散的后验估计器,可以有效地近似高维和多模态后验分布,这对于复杂的参数估计问题至关重要。

关键设计:JADAI使用深度神经网络来实现策略网络、历史网络和推断网络。策略网络可以使用各种强化学习算法进行训练,例如策略梯度或Q-learning。历史网络可以使用循环神经网络(RNN)或Transformer来编码序列化的实验信息。推断网络使用基于扩散的模型来生成后验分布的样本。损失函数通常包括一个衡量后验误差的项,例如KL散度或Wasserstein距离,以及一个正则化项,以防止过拟合。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

JADAI在多个标准自适应设计基准测试中取得了优越或具有竞争力的性能。具体来说,JADAI能够更有效地减少后验误差,并更快地收敛到真实的参数值。与传统的自适应设计方法相比,JADAI能够更好地处理高维和多模态后验分布,并且能够自动学习最优的实验策略。这些结果表明JADAI是一个有效的自适应设计和贝叶斯推断框架。

🎯 应用场景

JADAI可应用于科学实验设计、医疗诊断、机器人探索等领域。例如,在药物研发中,可以优化临床试验的设计,以更有效地评估药物的疗效。在机器人探索中,可以指导机器人选择最佳的探索路径,以更快地构建环境地图。该研究有助于提高实验效率、降低成本,并加速科学发现。

📄 摘要(原文)

We consider problems of parameter estimation where design variables can be actively optimized to maximize information gain. To this end, we introduce JADAI, a framework that jointly amortizes Bayesian adaptive design and inference by training a policy, a history network, and an inference network end-to-end. The networks minimize a generic loss that aggregates incremental reductions in posterior error along experimental sequences. Inference networks are instantiated with diffusion-based posterior estimators that can approximate high-dimensional and multimodal posteriors at every experimental step. Across standard adaptive design benchmarks, JADAI achieves superior or competitive performance.