Amortizing Causal Sensitivity Analysis via Prior Data-Fitted Networks

📄 arXiv: 2605.10590v1 📥 PDF

作者: Emil Javurek, Dennis Frauen, Marie Brockschmidt, Jonas Schweisthal, Stefan Feuerriegel

分类: stat.ML, cs.LG

发布日期: 2026-05-11


💡 一句话要点

提出基于先验数据拟合网络的摊销化因果敏感性分析方法,实现高效因果推断

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 因果推断 敏感性分析 摊销化学习 上下文学习 基础模型 拉格朗日优化 未观测混杂

📋 核心要点

  1. 现有因果敏感性分析方法依赖单实例计算,导致在面对不同数据集、查询条件或敏感度参数时,重复计算开销巨大,难以满足实时分析需求。
  2. 论文提出一种摊销化学习框架,通过先验数据拟合网络将敏感度边界计算转化为上下文学习任务,利用拉格朗日标量化技术自动生成训练标签。
  3. 实验证明该方法在保持分析精度的同时,推理速度较传统方法提升了几个数量级,为因果推断领域提供了首个高效的基础模型解决方案。

📝 摘要(中文)

因果敏感性分析旨在量化存在未观测混杂因素时的因果效应边界。然而,现有方法多为单实例过程,每当数据集、因果查询、敏感度水平或处理变量发生变化时,均需重新计算,计算成本高昂。本文提出了一种基于先验数据拟合网络(Prior-Data Fitted Networks)的上下文学习(In-context Learning)方法,实现了因果敏感性分析的摊销化。核心挑战在于训练数据中无法直接获取敏感度边界,为此,作者开发了一种适用于广义处理敏感度模型的通用先验数据构建方法。该方法通过拉格朗日标量化目标函数,在因果效应极值优化与敏感度模型违约之间进行权衡,从而生成训练标签,避免了针对特定模型的解析推导。理论上,该方法在凸性和线性条件下可恢复完整的帕累托前沿。实验表明,该方法在多种数据集和查询场景下,推理速度比传统方法快几个数量级,是首个用于因果敏感性分析的基础模型。

🔬 方法详解

问题定义:因果敏感性分析旨在评估未观测混杂因素对因果效应估计的影响。现有方法通常针对特定实例进行优化,当因果查询或敏感度参数改变时,必须重新运行复杂的优化过程,导致计算效率低下,无法支持大规模或动态的因果分析需求。

核心思路:论文引入“摊销化”(Amortization)思想,通过训练一个能够学习敏感度边界映射关系的神经网络,将原本需要实时优化的过程转化为前向推理过程,从而实现计算开销的显著降低。

技术框架:整体框架基于先验数据拟合网络。首先,通过构建通用的先验数据生成机制,模拟不同敏感度场景下的因果效应边界;其次,利用Transformer或类似架构作为基础模型,通过上下文学习处理输入的因果查询与敏感度参数;最后,通过模型推理直接输出因果效应的上下界。

关键创新:最重要的创新在于提出了一种无需模型特定解析推导的通用标签生成方法。通过拉格朗日标量化目标函数,在因果效应极值优化与敏感度模型违约之间建立权衡,使得模型能够学习到通用的敏感度边界函数,而非针对单一模型的特化解。

关键设计:关键技术细节包括拉格朗日乘子的动态调整,以确保在训练过程中能够覆盖完整的帕累托前沿。此外,该方法在满足标准凸性和线性条件下,能够保证所学习到的边界与理论最优解的一致性,从而在保证分析鲁棒性的前提下实现高效推理。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验在多个基准数据集上进行了广泛评估,结果显示该方法在处理不同因果查询和敏感度水平时,推理速度较传统单实例优化方法提升了几个数量级。同时,该模型在保持与传统方法高度一致的边界估计精度的基础上,展现了极强的泛化能力,验证了其作为因果敏感性分析基础模型的有效性。

🎯 应用场景

该方法适用于需要快速评估因果效应稳健性的领域,如医疗决策支持、政策效果评估及经济学分析。在这些场景中,研究人员常需在不同假设条件下反复测试因果结论,该方法能大幅缩短分析周期,提升决策效率,并为大规模因果推断任务提供自动化工具。

📄 摘要(原文)

Causal sensitivity analysis aims to provide bounds for causal effect estimates in the presence of unobserved confounding. However, existing methods for causal sensitivity analysis are per-instance procedures, meaning that changes to the dataset, causal query, sensitivity level, or treatment require new computation. Here, we instead present an in-context learning approach. Specifically, we propose an amortized approach to causal sensitivity analysis based on prior-data fitted networks. A key challenge is that the sensitivity bounds are not directly available when sampling training data. To address this, we develop a general prior-data construction that is applicable across the class of generalized treatment sensitivity models. Our construction involves a Lagrangian scalarization of the objective to generate training labels for the bounds through a tradeoff between causal effect min/max-imization and sensitivity model violation, which avoids model-specific analytical derivations. We further show that, under standard convexity and linearity conditions, our objective recovers the full Pareto frontier of solutions. Empirically, we demonstrate our amortized approach across various datasets, causal queries, and sensitivity levels, where our approach achieves a test-time computation that is orders of magnitude faster than per-instance methods. To the best of our knowledge, ours is the first foundation model for in-context learning for causal sensitivity analysis.