Calibrated and Robust Foundation Models for Vision-Language and Medical Image Tasks Under Distribution Shift
作者: Behraj Khan, Tahir Qasim Syed, Nouman M. Durrani, Bilal Naseem, Shabir Ahmad, Rizwan Qureshi
分类: cs.CV, cs.LG
发布日期: 2025-07-12 (更新: 2025-07-20)
💡 一句话要点
提出StaRFM,融合FIP和CMP,提升Foundation Model在分布偏移下的鲁棒性和校准性
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: Foundation Model 分布偏移 置信度校准 Fisher信息惩罚 医学图像分割
📋 核心要点
- 现有Foundation Model在分布偏移和置信度失调下表现不佳,限制了其在实际场景中的应用。
- StaRFM融合Fisher信息惩罚(FIP)和置信度失调惩罚(CMP),同时解决嵌入偏移和校准问题。
- 实验表明,StaRFM在视觉和医学图像任务上显著提升了准确率、分割精度和校准性能。
📝 摘要(中文)
Foundation Model,如CLIP和SAM,通过小样本迁移学习推动了计算机视觉和医学图像处理的发展,尤其是在数据有限的计算机辅助药物设计(CADD)中。然而,它们的部署面临两个主要挑战:分布偏移(预训练和后训练数据分布不同,例如,由于中心间图像采集差异)和置信度失调(导致过度自信的错误)。视觉-语言模型(如CLIP)遭受2D嵌入偏移(图像-文本未对齐),而医学模型(如SAM)遇到3D领域偏移(如扫描仪变化)和体素级校准需求。现有解决方案是特定于领域的。我们提出StaRFM,融合Fisher信息惩罚(FIP)和置信度失调惩罚(CMP),以应对这两个挑战。它应用FIP(通过patch-wise正则化扩展到3D)来减少嵌入偏移,并重新制定CMP以进行体素级预测,从而校准分割不确定性。我们推导了PAC-Bayes界限。FIP通过Fisher-Rao范数控制泛化,CMP通过Brier分数最小化降低校准误差。StaRFM在19个视觉数据集(如ImageNet、Office-Home)上超越基线+3.5%的准确率和降低28%的ECE,在医学基准(如BraTS、ATLAS)上实现超过SAM-FT +4.2%的DSC和4.8mm的HD95,并将跨域差距降低高达20%。该框架是即插即用的,只需要最小的架构更改。
🔬 方法详解
问题定义:现有Foundation Model在实际应用中,面临着由于数据分布差异导致的性能下降问题,具体表现为视觉-语言模型的嵌入偏移和医学图像模型的3D领域偏移。此外,模型预测的置信度与实际准确率不匹配,导致过度自信的错误,严重影响了模型的可靠性。现有方法往往针对特定领域,缺乏通用性。
核心思路:StaRFM的核心思路是通过引入Fisher信息惩罚(FIP)来约束模型的参数学习,使其对分布偏移更加鲁棒,并利用置信度失调惩罚(CMP)来校准模型的预测置信度,使其与实际准确率更加一致。这种双重惩罚机制旨在同时解决分布偏移和置信度失调问题,提高模型的泛化能力和可靠性。
技术框架:StaRFM框架主要包含两个核心模块:FIP模块和CMP模块。FIP模块通过计算Fisher信息矩阵,并对模型参数施加正则化,从而减小模型对训练数据的过度拟合,提高其对新数据的泛化能力。CMP模块则通过最小化Brier分数,来校准模型的预测置信度,使其与实际准确率更加一致。这两个模块可以即插即用地集成到现有的Foundation Model中,无需对模型架构进行大幅修改。
关键创新:StaRFM的关键创新在于其同时解决了Foundation Model在分布偏移下的鲁棒性和校准性问题。FIP模块通过patch-wise正则化扩展到3D,使其能够处理医学图像中的3D领域偏移。CMP模块被重新制定以进行体素级预测,从而能够校准医学图像分割任务中的体素级不确定性。此外,该框架具有通用性,可以应用于不同的Foundation Model和任务。
关键设计:FIP模块的关键设计在于Fisher信息矩阵的计算方式和正则化强度的选择。CMP模块的关键设计在于Brier分数的计算方式和校准目标的设定。论文还推导了PAC-Bayes界限,为FIP和CMP的有效性提供了理论支持。具体的参数设置和损失函数选择需要根据具体的任务和数据集进行调整。
🖼️ 关键图片
📊 实验亮点
StaRFM在19个视觉数据集上取得了显著的性能提升,准确率提高了3.5%,ECE降低了28%。在医学图像分割任务中,StaRFM在BraTS和ATLAS数据集上分别实现了+4.2%的DSC和4.8mm的HD95,显著优于SAM-FT等基线方法。此外,StaRFM还能够有效降低跨域差距,最高可达20%。
🎯 应用场景
StaRFM具有广泛的应用前景,可用于提升计算机视觉和医学图像分析任务中Foundation Model的性能和可靠性。例如,在自动驾驶领域,可以提高模型在不同天气和光照条件下的目标检测和识别能力。在医学图像分析领域,可以提高模型在不同扫描仪和患者群体中的疾病诊断和分割精度,辅助医生进行更准确的临床决策。
📄 摘要(原文)
Foundation models like CLIP and SAM have advanced computer vision and medical imaging via low-shot transfer learning, aiding CADD with limited data. However, their deployment faces two key challenges. \textit{distribution shift} where pre-training and post-training data distributions differ (e.g., due to inter-center image acquisition) and \textit{confidence misalignment}, which leads to overconfident errors. These issues surface differently, vision-language models (e.g., CLIP) suffer from 2D embedding shift (image-text misalignment), while medical models (e.g., SAM) encounter 3D domain shifts (e.g., scanner variation) and voxel-wise calibration need. Existing solutions are domain-specific. We propose \textbf{StaRFM}, a fusion of Fisher information penalty (FIP) and confidence misalignment penalty (CMP) tackling both challenges. It applies FIP, extended to 3D via patch-wise regularization, to reduce embedding shift, and CMP, reformulated for voxel-level predictions, to calibrate segmentation uncertainty. We derive PAC-Bayes bounds. FIP controls generalization via the Fisher-Rao norm, and CMP reduces calibration error via Brier score minimization. StaRFM surpasses baselines by \texttt{+}3.5\% accuracy and 28\% lower ECE on 19 vision datasets (e.g., ImageNet, Office-Home), achieves +4.2\% DSC over SAM-FT and 4.8mm HD95 on medical benchmarks (e.g., BraTS, ATLAS), and reduces cross-domain gaps by up to 20\%. The framework is plug-and-play, requiring minimal architectural changes. Code and models are available at: \href{https://anonymous.4open.science/r/StaRFM-C0CD/}{\textcolor{blue}{\underline{StaRFM}}}