Amortized Active Causal Induction with Deep Reinforcement Learning
作者: Yashas Annadani, Panagiotis Tigas, Stefan Bauer, Adam Foster
分类: cs.LG, cs.AI
发布日期: 2024-05-26
💡 一句话要点
提出CAASL,利用深度强化学习进行主动因果结构学习,无需似然函数。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 主动因果结构学习 深度强化学习 因果推断 Transformer网络 干预设计
📋 核心要点
- 现有主动因果结构学习方法通常依赖于似然函数,计算成本高昂且难以实时应用。
- CAASL使用深度强化学习训练一个基于Transformer的摊销网络,直接学习干预策略,避免了似然函数计算。
- 实验表明,CAASL在合成数据和基因表达模拟中,能更准确地估计因果图,并具有良好的泛化能力。
📝 摘要(中文)
本文提出了一种因果摊销主动结构学习(CAASL)方法,该方法是一种主动干预设计策略,可以选择自适应、实时的干预,并且不需要访问似然函数。该策略是一个基于Transformer的摊销网络,通过强化学习在设计环境的模拟器上进行训练,并使用奖励函数来衡量真实因果图与从收集的数据推断出的因果图后验之间的接近程度。在合成数据和单细胞基因表达模拟器上,我们通过实验证明,通过我们的策略获取的数据比其他策略能更好地估计底层因果图。我们的设计策略成功地实现了训练环境分布上的摊销干预设计,同时也很好地推广到测试时设计环境中的分布偏移。此外,我们的策略还展示了对维度高于训练期间的设计环境以及它没有训练过的干预类型的极佳零样本泛化能力。
🔬 方法详解
问题定义:论文旨在解决主动因果结构学习问题,即如何在有限的干预预算下,选择最优的干预策略,以最快地学习到真实的因果图结构。现有方法,如基于似然函数的方法,计算复杂度高,难以处理大规模数据和实时场景。此外,这些方法通常需要对数据分布做出强假设,泛化能力有限。
核心思路:论文的核心思路是利用深度强化学习,直接学习一个干预策略。该策略能够根据当前已观测到的数据,动态地选择下一步要进行的干预。通过强化学习,策略可以学习到如何最大化信息增益,从而更快地收敛到真实的因果图结构。这种方法避免了对似然函数的依赖,降低了计算复杂度,并提高了泛化能力。
技术框架:CAASL的整体框架包括以下几个主要模块:1) 环境模拟器:用于模拟因果干预的过程,生成观测数据。2) 策略网络:基于Transformer的神经网络,输入是当前观测数据,输出是下一步要进行的干预。3) 因果图推断模块:根据观测数据推断因果图的后验分布。4) 奖励函数:衡量推断出的因果图与真实因果图之间的差异,用于训练策略网络。整个流程是,策略网络根据当前观测数据选择干预,环境模拟器执行干预并生成新的观测数据,因果图推断模块更新因果图后验,奖励函数计算奖励,策略网络根据奖励进行更新。
关键创新:CAASL的关键创新在于使用深度强化学习来学习主动干预策略,从而避免了对似然函数的依赖。此外,使用Transformer网络作为策略网络,可以有效地处理序列数据和长程依赖关系。另一个创新点是,该方法具有良好的泛化能力,可以推广到维度更高的环境和未见过的干预类型。
关键设计:策略网络采用Transformer结构,输入是观测数据的序列,输出是干预动作的概率分布。奖励函数采用结构汉明距离,衡量推断出的因果图与真实因果图之间的差异。强化学习算法采用PPO(Proximal Policy Optimization)。训练过程中,使用多个环境并行训练,以提高训练效率。为了提高泛化能力,采用了数据增强和正则化技术。
🖼️ 关键图片
📊 实验亮点
实验结果表明,CAASL在合成数据和单细胞基因表达模拟器上,能够比其他主动学习策略更准确地估计底层因果图。尤其是在高维数据和分布偏移的情况下,CAASL表现出更强的鲁棒性和泛化能力。此外,CAASL还展示了对未见过的干预类型的零样本泛化能力。
🎯 应用场景
CAASL可应用于基因调控网络推断、社交网络分析、推荐系统优化等领域。通过主动干预,可以更有效地发现因果关系,从而更好地理解复杂系统,并进行精准决策。例如,在药物研发中,可以利用CAASL选择最优的药物组合,以达到最佳的治疗效果。
📄 摘要(原文)
We present Causal Amortized Active Structure Learning (CAASL), an active intervention design policy that can select interventions that are adaptive, real-time and that does not require access to the likelihood. This policy, an amortized network based on the transformer, is trained with reinforcement learning on a simulator of the design environment, and a reward function that measures how close the true causal graph is to a causal graph posterior inferred from the gathered data. On synthetic data and a single-cell gene expression simulator, we demonstrate empirically that the data acquired through our policy results in a better estimate of the underlying causal graph than alternative strategies. Our design policy successfully achieves amortized intervention design on the distribution of the training environment while also generalizing well to distribution shifts in test-time design environments. Further, our policy also demonstrates excellent zero-shot generalization to design environments with dimensionality higher than that during training, and to intervention types that it has not been trained on.