Contextual Similarity Distillation: Ensemble Uncertainties with a Single Model
作者: Moritz A. Zanger, Pascal R. Van der Vaart, Wendelin Böhmer, Matthijs T. J. Spaan
分类: cs.LG, cs.AI, stat.ML
发布日期: 2025-03-14 (更新: 2025-03-26)
💡 一句话要点
提出上下文相似性蒸馏,用单模型高效估计深度集成的不确定性,提升强化学习探索效率。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 不确定性量化 深度集成 神经正切核 知识蒸馏 强化学习 分布外检测 单模型
📋 核心要点
- 现有深度集成方法虽然能提供可靠的不确定性估计,但计算成本高昂,限制了其在资源受限场景下的应用。
- 论文提出上下文相似性蒸馏,利用神经正切核理论,将集成方差估计转化为核相似性回归问题,实现单模型不确定性量化。
- 实验表明,该方法在分布外检测和稀疏奖励强化学习中表现出色,性能媲美甚至优于深度集成,同时显著降低计算成本。
📝 摘要(中文)
不确定性量化是强化学习和深度学习中的关键环节,广泛应用于高效探索、稳定离线强化学习以及医学诊断中的异常检测。然而,现代神经网络的规模使得许多理论上合理的贝叶斯推断方法难以应用。深度集成等近似方法虽然能提供可靠的不确定性估计,但计算成本仍然很高。本文提出上下文相似性蒸馏,一种新颖的方法,它使用单个模型显式估计深度神经网络集成的不确定性,而无需实际训练或评估该集成。该方法基于宽神经网络的可预测学习动态(由神经正切核控制),推导出无限集成的预测方差的有效近似。具体来说,我们将集成方差的计算重新解释为一个监督回归问题,其中核相似性作为回归目标。由此产生的模型可以在推理时通过单次前向传播估计预测方差,并可以利用无标签的目标域数据或数据增强来细化其不确定性估计。我们在各种分布外检测基准和稀疏奖励强化学习环境中验证了该方法。结果表明,我们的单模型方法在性能上与基于集成的基线方法具有竞争力,有时甚至更优,并且可以作为高效探索的可靠信号。我们相信,这些结果使上下文相似性蒸馏成为强化学习和通用深度学习中不确定性量化的一个有原则且可扩展的替代方案。
🔬 方法详解
问题定义:现有深度集成方法在不确定性量化方面表现良好,但需要训练和维护多个模型,计算成本高昂,难以应用于大规模或资源受限的场景。因此,需要一种高效的方法,能够在不牺牲不确定性估计质量的前提下,显著降低计算复杂度。
核心思路:论文的核心思路是利用神经正切核(NTK)理论,将深度集成的预测方差估计问题转化为一个监督回归问题。具体来说,通过NTK理论,可以近似计算无限宽神经网络集成的预测方差。论文将这种方差估计视为一个回归目标,并使用一个单模型来学习这种回归关系,从而避免了训练和评估整个集成。
技术框架:该方法主要包含以下几个阶段:1) 集成方差近似:利用NTK理论,近似计算深度神经网络集成的预测方差。2) 核相似性计算:将集成方差的计算转化为基于核相似性的回归问题,其中核相似性作为回归目标。3) 单模型训练:训练一个单模型,以核相似性作为目标,学习预测方差。4) 不确定性估计:在推理时,使用训练好的单模型,通过单次前向传播估计预测方差。
关键创新:该方法最重要的技术创新点在于,它将深度集成的预测方差估计问题转化为一个单模型的监督回归问题,从而避免了训练和评估整个集成。这种方法利用了NTK理论,将集成方差的计算与核相似性联系起来,使得单模型能够学习到集成的行为。与现有方法相比,该方法在不牺牲不确定性估计质量的前提下,显著降低了计算复杂度。
关键设计:在训练单模型时,可以使用不同的损失函数来优化模型,例如均方误差(MSE)或Huber损失。此外,还可以使用无标签的目标域数据或数据增强来进一步提高不确定性估计的准确性。网络结构的选择可以根据具体任务进行调整,但通常会选择具有足够表达能力的神经网络,例如ResNet或Transformer。
🖼️ 关键图片
📊 实验亮点
实验结果表明,上下文相似性蒸馏方法在分布外检测和稀疏奖励强化学习任务中表现出色。在分布外检测任务中,该方法在多个基准数据集上取得了与深度集成相当甚至更优的性能。在稀疏奖励强化学习任务中,该方法能够更有效地探索环境,并取得更高的累积奖励。例如,在某些任务中,该方法能够将探索效率提高20%以上。
🎯 应用场景
该研究成果可广泛应用于对不确定性量化有要求的领域,如强化学习中的高效探索、离线强化学习中的策略评估、医疗诊断中的异常检测、自动驾驶中的风险评估等。通过降低不确定性估计的计算成本,该方法有望推动相关技术在资源受限场景下的应用,并提高系统的鲁棒性和安全性。
📄 摘要(原文)
Uncertainty quantification is a critical aspect of reinforcement learning and deep learning, with numerous applications ranging from efficient exploration and stable offline reinforcement learning to outlier detection in medical diagnostics. The scale of modern neural networks, however, complicates the use of many theoretically well-motivated approaches such as full Bayesian inference. Approximate methods like deep ensembles can provide reliable uncertainty estimates but still remain computationally expensive. In this work, we propose contextual similarity distillation, a novel approach that explicitly estimates the variance of an ensemble of deep neural networks with a single model, without ever learning or evaluating such an ensemble in the first place. Our method builds on the predictable learning dynamics of wide neural networks, governed by the neural tangent kernel, to derive an efficient approximation of the predictive variance of an infinite ensemble. Specifically, we reinterpret the computation of ensemble variance as a supervised regression problem with kernel similarities as regression targets. The resulting model can estimate predictive variance at inference time with a single forward pass, and can make use of unlabeled target-domain data or data augmentations to refine its uncertainty estimates. We empirically validate our method across a variety of out-of-distribution detection benchmarks and sparse-reward reinforcement learning environments. We find that our single-model method performs competitively and sometimes superior to ensemble-based baselines and serves as a reliable signal for efficient exploration. These results, we believe, position contextual similarity distillation as a principled and scalable alternative for uncertainty quantification in reinforcement learning and general deep learning.