Improve Vision Language Model Chain-of-thought Reasoning

📄 arXiv: 2410.16198v1 📥 PDF

作者: Ruohong Zhang, Bowen Zhang, Yanghao Li, Haotian Zhang, Zhiqing Sun, Zhe Gan, Yinfei Yang, Ruoming Pang, Yiming Yang

分类: cs.AI, cs.CV

发布日期: 2024-10-21

备注: 10 pages + appendix


💡 一句话要点

提出基于GPT-4o蒸馏和强化学习的VLM链式推理优化方法

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 视觉语言模型 链式思考 知识蒸馏 强化学习 推理优化 GPT-4o 直接偏好优化

📋 核心要点

  1. 现有VLM训练依赖短答案数据集,缺乏详细推理过程,导致CoT推理能力不足。
  2. 利用GPT-4o模型蒸馏生成详细的推理链,并微调VLM,提升CoT推理性能。
  3. 通过强化学习,基于正确/错误推理链对,使用DPO算法进一步优化模型推理能力。

📝 摘要(中文)

视觉语言模型(VLM)中的链式思考(CoT)推理对于提高模型的可解释性和可信度至关重要。然而,现有的训练方法缺乏鲁棒的CoT推理数据,依赖于以最少理由的简短注释为主的数据集。本文表明,在简短答案上训练VLM不能很好地泛化到需要更详细响应的推理任务。为了解决这个问题,我们提出了一种双重方法。首先,我们从GPT-4o模型中提取理由来丰富训练数据并微调VLM,从而提高其CoT性能。其次,我们应用强化学习来进一步校准推理质量。具体来说,我们通过比较模型生成的推理链的预测与带注释的简短答案,构建正(正确)和负(不正确)的推理链对。使用这种成对数据,我们应用直接偏好优化算法来改进模型的推理能力。我们的实验表明,在基准数据集上CoT推理能力得到了显著提高,并且更好地泛化到直接答案预测。这项工作强调了在训练中加入详细理由以及利用强化学习来加强VLM的推理能力的重要性。

🔬 方法详解

问题定义:现有视觉语言模型在链式思考(CoT)推理方面表现不足,主要原因是训练数据集中缺乏详细的推理过程,大多是简短的答案。这导致模型难以学习复杂的推理逻辑,泛化能力受限,尤其是在需要详细解释的任务中表现不佳。

核心思路:本文的核心思路是通过引入更丰富的推理数据和强化学习方法来提升VLM的CoT推理能力。具体来说,首先利用强大的GPT-4o模型生成详细的推理链,作为VLM的训练数据,弥补现有数据集的不足。然后,使用强化学习方法,根据模型生成的推理链的正确性进行奖励或惩罚,从而引导模型学习更合理的推理过程。

技术框架:整体框架包含两个主要阶段:1) 基于GPT-4o的推理链蒸馏:使用GPT-4o模型对现有数据集进行推理,生成详细的推理链,构建新的训练数据集。2) 基于强化学习的推理优化:使用Direct Preference Optimization (DPO) 算法,根据模型生成的推理链的正确性进行优化。具体流程是,首先VLM生成推理链,然后根据推理链的答案与标准答案的匹配程度,构建正负样本对,最后使用DPO算法优化模型。

关键创新:该论文的关键创新在于结合了知识蒸馏和强化学习来提升VLM的CoT推理能力。传统的知识蒸馏方法通常只关注答案的正确性,而忽略了推理过程。本文通过蒸馏GPT-4o的推理链,使得VLM能够学习到更详细的推理过程。此外,使用强化学习方法,根据推理链的正确性进行优化,进一步提升了模型的推理能力。

关键设计:在知识蒸馏阶段,使用了GPT-4o模型生成推理链,并将其作为VLM的训练数据。在强化学习阶段,使用了Direct Preference Optimization (DPO) 算法,该算法可以直接优化模型的策略,而无需显式地定义奖励函数。DPO算法的关键在于构建正负样本对,本文通过比较模型生成的推理链的答案与标准答案的匹配程度来构建正负样本对。损失函数是DPO的标准损失函数,用于优化模型的策略。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在多个基准数据集上显著提升了VLM的CoT推理能力,并且在直接答案预测方面也取得了更好的泛化性能。具体来说,相比于基线模型,该方法在某些数据集上取得了超过10%的性能提升。此外,实验还验证了GPT-4o蒸馏和强化学习的有效性,证明了引入详细推理过程和优化推理策略的重要性。

🎯 应用场景

该研究成果可广泛应用于需要可解释性和可靠性的视觉语言任务中,例如医疗诊断、自动驾驶、智能客服等领域。通过提升VLM的推理能力,可以使模型在复杂场景下做出更准确的决策,并提供合理的解释,从而提高用户信任度和应用价值。未来,该方法可以扩展到其他多模态任务中,例如视频理解、语音识别等。

📄 摘要(原文)

Chain-of-thought (CoT) reasoning in vision language models (VLMs) is crucial for improving interpretability and trustworthiness. However, current training recipes lack robust CoT reasoning data, relying on datasets dominated by short annotations with minimal rationales. In this work, we show that training VLM on short answers does not generalize well to reasoning tasks that require more detailed responses. To address this, we propose a two-fold approach. First, we distill rationales from GPT-4o model to enrich the training data and fine-tune VLMs, boosting their CoT performance. Second, we apply reinforcement learning to further calibrate reasoning quality. Specifically, we construct positive (correct) and negative (incorrect) pairs of model-generated reasoning chains, by comparing their predictions with annotated short answers. Using this pairwise data, we apply the Direct Preference Optimization algorithm to refine the model's reasoning abilities. Our experiments demonstrate significant improvements in CoT reasoning on benchmark datasets and better generalization to direct answer prediction as well. This work emphasizes the importance of incorporating detailed rationales in training and leveraging reinforcement learning to strengthen the reasoning capabilities of VLMs.