START: A Generalized State Space Model with Saliency-Driven Token-Aware Transformation
作者: Jintao Guo, Lei Qi, Yinghuan Shi, Yang Gao
分类: cs.CV
发布日期: 2024-10-21 (更新: 2025-01-07)
备注: Accepted by NeurIPS2024. The code is available at https://github.com/lingeringlight/START
🔗 代码/项目: GITHUB
💡 一句话要点
提出基于显著性驱动的Token感知变换状态空间模型START,提升域泛化能力。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 领域泛化 状态空间模型 显著性检测 Token感知变换 Mamba 深度学习 计算机视觉
📋 核心要点
- 现有领域泛化方法依赖CNN易过拟合源域,ViT计算成本高昂,难以兼顾性能与效率。
- 提出START模型,通过显著性驱动的Token感知变换,选择性抑制领域特定特征,减少领域差异。
- START在五个基准测试中超越现有SOTA方法,同时保持线性复杂度,具有高效的泛化能力。
📝 摘要(中文)
领域泛化(DG)旨在通过从多个源域学习,使模型能够泛化到未见过的目标域。现有的DG方法主要依赖于卷积神经网络(CNN),由于其感受野有限,固有地学习纹理偏差,容易过度拟合源域。虽然一些工作引入了基于Transformer的方法(ViT)用于DG,以利用全局感受野,但由于自注意力的二次复杂度,这些方法会产生很高的计算成本。最近,以Mamba为代表的先进状态空间模型(SSM)在监督学习任务中表现出良好的结果,在训练期间实现了序列长度的线性复杂性,并在推理期间实现了快速的类RNN计算。受此启发,我们研究了Mamba模型在领域偏移下的泛化能力,发现SSM中依赖于输入的矩阵可能会累积和放大特定领域的特征,从而阻碍模型泛化。为了解决这个问题,我们提出了一种新的基于SSM的架构,具有基于显著性的token感知变换(即START),它实现了最先进(SOTA)的性能,并为CNN和ViT提供了一个有竞争力的替代方案。我们的START可以选择性地扰动和抑制SSM的输入相关矩阵中显著token的特定领域特征,从而有效地减少不同领域之间的差异。在五个基准上的大量实验表明,START优于现有的SOTA DG方法,并具有高效的线性复杂度。我们的代码可在https://github.com/lingeringlight/START获得。
🔬 方法详解
问题定义:领域泛化(DG)旨在解决模型在未见过的目标域上的泛化问题。现有方法,特别是基于CNN的方法,容易受到纹理偏差的影响,导致在源域上过拟合。而基于Transformer的方法虽然具有全局感受野,但计算复杂度高,难以实际应用。因此,如何在保证模型泛化能力的同时,降低计算成本是一个关键问题。
核心思路:论文的核心思路是利用状态空间模型(SSM)的线性复杂度优势,并在此基础上,通过显著性驱动的Token感知变换,选择性地抑制输入相关矩阵中特定领域的特征。这样既能保持模型的全局感受野,又能避免模型过度关注领域相关的特征,从而提高泛化能力。
技术框架:START模型基于Mamba架构,主要包含以下几个模块:1) 输入嵌入层:将输入数据转换为token嵌入;2) SSM层:利用Mamba进行序列建模;3) 显著性检测模块:计算每个token的显著性得分;4) Token感知变换模块:根据显著性得分,对SSM的输入相关矩阵进行扰动或抑制;5) 输出层:将SSM的输出转换为最终预测结果。整个流程是,输入数据经过嵌入后,通过SSM层进行建模,然后利用显著性检测模块计算每个token的重要性,最后根据重要性对SSM的参数进行调整,以抑制领域特定特征。
关键创新:START的关键创新在于提出了显著性驱动的Token感知变换。与传统的领域泛化方法不同,START不是直接对输入数据进行处理,而是通过调整SSM的内部参数,来抑制领域特定特征。这种方法更加灵活,可以更好地适应不同的领域偏移。此外,START利用显著性检测模块来确定哪些token是重要的,从而可以更加精确地抑制领域特定特征。
关键设计:START的关键设计包括:1) 显著性检测模块:可以使用不同的显著性检测方法,例如梯度积分、注意力权重等。论文中具体使用了哪种方法未知。2) Token感知变换:可以采用不同的变换方式,例如加性扰动、乘性抑制等。具体采用哪种方式未知。3) 损失函数:除了标准的分类损失函数外,还可以添加一些正则化项,以鼓励模型学习更加通用的特征。具体使用了哪些正则化项未知。
🖼️ 关键图片
📊 实验亮点
START在五个领域泛化基准测试中取得了SOTA性能,证明了其有效性。具体性能数据未知,但摘要中强调了START优于现有的SOTA DG方法,并具有高效的线性复杂度。这表明START在性能和效率方面都具有优势,为领域泛化问题提供了一个有竞争力的解决方案。
🎯 应用场景
START模型可应用于各种需要领域泛化的场景,例如图像识别、自然语言处理等。在医疗诊断领域,可以利用START模型,从多个医院的数据中学习,从而提高模型在新的医院数据上的诊断准确率。在自动驾驶领域,可以利用START模型,从不同的城市数据中学习,从而提高模型在新的城市道路上的驾驶安全性。该研究具有重要的实际价值和广泛的应用前景。
📄 摘要(原文)
Domain Generalization (DG) aims to enable models to generalize to unseen target domains by learning from multiple source domains. Existing DG methods primarily rely on convolutional neural networks (CNNs), which inherently learn texture biases due to their limited receptive fields, making them prone to overfitting source domains. While some works have introduced transformer-based methods (ViTs) for DG to leverage the global receptive field, these methods incur high computational costs due to the quadratic complexity of self-attention. Recently, advanced state space models (SSMs), represented by Mamba, have shown promising results in supervised learning tasks by achieving linear complexity in sequence length during training and fast RNN-like computation during inference. Inspired by this, we investigate the generalization ability of the Mamba model under domain shifts and find that input-dependent matrices within SSMs could accumulate and amplify domain-specific features, thus hindering model generalization. To address this issue, we propose a novel SSM-based architecture with saliency-based token-aware transformation (namely START), which achieves state-of-the-art (SOTA) performances and offers a competitive alternative to CNNs and ViTs. Our START can selectively perturb and suppress domain-specific features in salient tokens within the input-dependent matrices of SSMs, thus effectively reducing the discrepancy between different domains. Extensive experiments on five benchmarks demonstrate that START outperforms existing SOTA DG methods with efficient linear complexity. Our code is available at https://github.com/lingeringlight/START.