CRoFT: Robust Fine-Tuning with Concurrent Optimization for OOD Generalization and Open-Set OOD Detection

📄 arXiv: 2405.16417v1 📥 PDF

作者: Lin Zhu, Yifeng Yang, Qinying Gu, Xinbing Wang, Chenghu Zhou, Nanyang Ye

分类: cs.CV

发布日期: 2024-05-26

备注: Accepted by ICML2024

🔗 代码/项目: GITHUB


💡 一句话要点

提出CRoFT框架,通过并发优化提升VL-PTM在OOD泛化和开集OOD检测中的鲁棒性。

🎯 匹配领域: 支柱三:空间感知与语义 (Perception & Semantics)

关键词: 视觉语言模型 OOD泛化 开集OOD检测 鲁棒微调 领域泛化

📋 核心要点

  1. 现有VL-PTM微调易损失泛化能力,难以应对真实场景中的协变量和语义偏移,对OOD数据的处理能力不足。
  2. CRoFT框架通过最小化训练数据上能量分数的梯度幅度,实现领域一致的分类损失Hessian矩阵,从而提升OOD泛化能力。
  3. 实验结果表明,CRoFT框架在OOD泛化和开集OOD检测任务上均表现出优越性,验证了该方法的有效性。

📝 摘要(中文)

近年来,视觉-语言预训练模型(VL-PTM)在开放词汇任务中取得了显著成功。然而,下游应用通常需要对VL-PTM进行进一步的微调,这可能会扭曲其通用知识,并损害其处理分布偏移的能力。在现实场景中,机器学习系统不可避免地会遇到协变量偏移(例如,图像风格的变化)和语义偏移(例如,测试时未见过的类别)。这突出了增强VL-PTM在协变量偏移上的OOD泛化能力,并同时检测语义偏移的未见类的重要性。因此,一个关键但未被充分探索的问题出现了:如何在微调过程中提高VL-PTM对闭集OOD数据的泛化能力,同时有效地检测开集未见类?在本文中,我们提出了一种新的OOD检测目标函数,该函数也有助于提高OOD泛化能力。我们表明,最小化训练数据上能量分数的梯度幅度会导致领域一致的分类损失Hessian矩阵,这是理论分析揭示的OOD泛化的一个强指标。基于这一发现,我们开发了一个统一的微调框架,允许同时优化这两个任务。大量的实验证明了我们方法的优越性。代码可在https://github.com/LinLLLL/CRoFT 获取。

🔬 方法详解

问题定义:论文旨在解决视觉-语言预训练模型(VL-PTM)在微调后,其泛化能力下降,难以同时应对协变量偏移和语义偏移的问题。现有方法通常专注于提升模型在特定分布上的性能,而忽略了模型在未知分布上的鲁棒性,尤其是在开放集OOD检测任务中,性能表现不佳。

核心思路:论文的核心思路是设计一种新的目标函数,该函数不仅能够优化模型的分类性能,还能提高模型对OOD数据的泛化能力和检测能力。通过最小化训练数据上能量分数的梯度幅度,使得模型学习到领域一致的Hessian矩阵,从而提升模型的泛化能力。同时,利用能量分数进行OOD检测,区分已知类和未知类。

技术框架:CRoFT框架是一个统一的微调框架,它允许同时优化OOD泛化和开集OOD检测两个任务。该框架主要包括以下几个模块:1) VL-PTM backbone:使用预训练的视觉-语言模型作为特征提取器。2) 分类器:将提取的特征映射到类别标签。3) 能量函数:用于评估输入样本的置信度,并用于OOD检测。4) 优化器:用于最小化目标函数,包括分类损失和能量梯度损失。

关键创新:该论文的关键创新在于提出了一种新的OOD检测目标函数,该函数能够同时提高OOD泛化能力。通过理论分析,论文证明了最小化训练数据上能量分数的梯度幅度会导致领域一致的分类损失Hessian矩阵,这是OOD泛化的一个强指标。这种方法不同于以往仅仅关注分类性能的微调方法,而是更加注重模型的鲁棒性和泛化能力。

关键设计:CRoFT框架的关键设计包括:1) 能量函数的选择:论文选择了一种基于softmax输出的能量函数,该函数能够有效地评估输入样本的置信度。2) 能量梯度损失:论文设计了一种能量梯度损失,用于最小化训练数据上能量分数的梯度幅度。3) 并发优化:论文采用并发优化策略,同时优化分类损失和能量梯度损失,从而实现OOD泛化和开集OOD检测的协同提升。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

论文通过大量实验验证了CRoFT框架的有效性。实验结果表明,CRoFT框架在多个OOD数据集上均取得了显著的性能提升,尤其是在开集OOD检测任务中,CRoFT框架能够有效地检测出未知类别,并保持较高的分类准确率。相较于基线方法,CRoFT框架在OOD泛化和开集OOD检测性能上均有明显优势。

🎯 应用场景

CRoFT框架可应用于各种需要处理分布偏移和未知类别的视觉-语言任务,例如图像分类、目标检测、图像检索等。在自动驾驶、医疗诊断、智能安防等领域,该框架能够提高系统的鲁棒性和可靠性,降低误判风险,具有重要的实际应用价值。

📄 摘要(原文)

Recent vision-language pre-trained models (VL-PTMs) have shown remarkable success in open-vocabulary tasks. However, downstream use cases often involve further fine-tuning of VL-PTMs, which may distort their general knowledge and impair their ability to handle distribution shifts. In real-world scenarios, machine learning systems inevitably encounter both covariate shifts (e.g., changes in image styles) and semantic shifts (e.g., test-time unseen classes). This highlights the importance of enhancing out-of-distribution (OOD) generalization on covariate shifts and simultaneously detecting semantic-shifted unseen classes. Thus a critical but underexplored question arises: How to improve VL-PTMs' generalization ability to closed-set OOD data, while effectively detecting open-set unseen classes during fine-tuning? In this paper, we propose a novel objective function of OOD detection that also serves to improve OOD generalization. We show that minimizing the gradient magnitude of energy scores on training data leads to domain-consistent Hessians of classification loss, a strong indicator for OOD generalization revealed by theoretical analysis. Based on this finding, we have developed a unified fine-tuning framework that allows for concurrent optimization of both tasks. Extensive experiments have demonstrated the superiority of our method. The code is available at https://github.com/LinLLLL/CRoFT.