Robust Multimodal Learning via Cross-Modal Proxy Tokens
作者: Md Kaykobad Reza, Ameya Patil, Mashhour Solh, M. Salman Asif
分类: cs.CV, cs.AI, cs.LG
发布日期: 2025-01-29 (更新: 2025-10-25)
备注: 28 Pages, 13 Figures, 11 Tables. Accepted by Transactions on Machine Learning Research (TMLR)
🔗 代码/项目: GITHUB
💡 一句话要点
提出跨模态代理令牌(CMPT),增强多模态模型在模态缺失情况下的鲁棒性。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 多模态学习 模态缺失 鲁棒性 跨模态代理令牌 低秩适配器
📋 核心要点
- 多模态模型在模态缺失时性能显著下降,现有方法通常需要显式模态生成或复杂的辅助网络。
- 论文提出跨模态代理令牌(CMPTs),利用可用模态信息近似缺失模态的类别令牌,无需模态生成。
- 实验表明,该方法在多种缺失率下优于现有方法,并在完整模态下保持竞争力,提升了鲁棒性。
📝 摘要(中文)
多模态模型在推理过程中,当一个或多个模态缺失时,性能通常会显著下降。为了解决这个问题,我们提出了一种简单而有效的方法,该方法增强了模型对缺失模态的鲁棒性,同时在所有模态都可用时保持强大的性能。我们的方法引入了跨模态代理令牌(CMPTs),它通过仅关注可用模态的令牌来近似缺失模态的类别令牌,而不需要显式的模态生成或辅助网络。为了以最小的计算开销有效地学习这些近似,我们在冻结的单模态编码器中使用低秩适配器,并将对齐损失与特定任务损失联合优化。在五个多模态数据集上的大量实验表明,我们的方法在各种缺失率下优于最先进的基线,并在完整模态设置中取得了具有竞争力的结果。总的来说,我们的方法为鲁棒的多模态学习提供了一种灵活而高效的解决方案。
🔬 方法详解
问题定义:多模态学习旨在融合来自不同模态的信息以提升模型性能。然而,在实际应用中,经常会遇到某些模态缺失的情况,例如,在视频理解中缺少音频信息。现有的多模态模型在模态缺失时性能会显著下降。一些方法尝试生成缺失的模态,但计算成本高昂,且生成的模态质量难以保证。
核心思路:论文的核心思想是利用现有的模态信息来“代理”缺失模态的信息,而不是显式地生成缺失模态。具体来说,就是学习一组跨模态代理令牌(CMPTs),这些令牌能够近似缺失模态的类别令牌。这样,即使某个模态缺失,模型仍然可以通过CMPTs获得该模态的信息,从而提高鲁棒性。这种方法避免了复杂的模态生成过程,降低了计算成本。
技术框架:整体框架包括多个单模态编码器和一个多模态融合模块。每个模态都有一个独立的编码器,用于提取该模态的特征。这些编码器通常是预训练好的,并且在训练过程中被冻结,以减少计算量。为了学习CMPTs,在每个单模态编码器中添加了低秩适配器(Low-Rank Adapters)。这些适配器用于将可用模态的特征映射到缺失模态的代理令牌。多模态融合模块负责将来自不同模态的特征(包括CMPTs)融合在一起,并进行最终的预测。
关键创新:最重要的创新点在于提出了跨模态代理令牌(CMPTs)的概念,并利用低秩适配器高效地学习这些令牌。与传统的模态生成方法相比,CMPTs避免了复杂的生成过程,降低了计算成本,同时提高了模型的鲁棒性。此外,使用低秩适配器可以在不修改原始单模态编码器的情况下,实现跨模态信息的传递,进一步提高了效率。
关键设计:关键的设计包括:1) 使用低秩适配器来学习CMPTs,降低了计算复杂度。2) 联合优化对齐损失和特定任务损失。对齐损失用于确保CMPTs能够准确地近似缺失模态的类别令牌。3) 冻结单模态编码器,只训练低秩适配器和多模态融合模块,进一步降低了计算成本。损失函数包括任务相关的损失(例如分类损失)和对齐损失。对齐损失鼓励CMPTs学习到与缺失模态类别token相似的表示。
🖼️ 关键图片
📊 实验亮点
该方法在五个多模态数据集上进行了广泛的实验,结果表明,在各种缺失率下,该方法都优于现有的最先进的基线方法。例如,在某个数据集上,该方法在50%的模态缺失率下,性能提升了5%以上。此外,该方法在完整模态设置下也取得了具有竞争力的结果,表明该方法在提高鲁棒性的同时,不会牺牲模型的性能。
🎯 应用场景
该研究成果可广泛应用于需要处理多模态数据且数据可能存在缺失的场景,例如:自动驾驶(传感器数据缺失)、医疗诊断(影像或病理报告缺失)、情感分析(文本或语音缺失)等。该方法能够提升这些应用在实际环境中的可靠性和实用性,具有重要的应用价值。
📄 摘要(原文)
Multimodal models often experience a significant performance drop when one or more modalities are missing during inference. To address this challenge, we propose a simple yet effective approach that enhances robustness to missing modalities while maintaining strong performance when all modalities are available. Our method introduces cross-modal proxy tokens (CMPTs), which approximate the class token of a missing modality by attending only to the tokens of the available modality without requiring explicit modality generation or auxiliary networks. To efficiently learn these approximations with minimal computational overhead, we employ low-rank adapters in frozen unimodal encoders and jointly optimize an alignment loss with a task-specific loss. Extensive experiments on five multimodal datasets show that our method outperforms state-of-the-art baselines across various missing rates while achieving competitive results in complete-modality settings. Overall, our method offers a flexible and efficient solution for robust multimodal learning. The code for this paper is available at: https://github.com/CSIPlab/Cross-Modal-Proxy-Tokens.