Beyond In-Distribution Success: Scaling Curves of CoT Granularity for Language Model Generalization

📄 arXiv: 2502.18273v1 📥 PDF

作者: Ru Wang, Wei Huang, Selena Song, Haoyu Zhang, Yusuke Iwasawa, Yutaka Matsuo, Jiaxian Guo

分类: cs.CL

发布日期: 2025-02-25


💡 一句话要点

研究表明,细粒度CoT数据能显著提升语言模型在复杂任务上的泛化能力。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 思维链推理 语言模型泛化 分布偏移 复合任务 Transformer CoT粒度 样本效率

📋 核心要点

  1. 现有QA模型在同分布数据上表现优异,但在分布偏移下泛化能力严重不足,尤其是在复杂任务中。
  2. 论文提出利用细粒度的思维链(CoT)数据训练语言模型,促使其学习更通用的推理模式,提升泛化能力。
  3. 实验表明,细粒度CoT数据能显著提升模型在分布偏移下的性能,且CoT具有更高的样本效率。

📝 摘要(中文)

本研究探讨了思维链(Chain-of-Thought, CoT)推理作为增强基于Transformer的语言模型(LM)在分布偏移下对新复合任务泛化能力的一种手段。通过对多个复合任务进行受控实验,我们揭示了三个关键见解:(1)虽然QA训练的模型在同分布数据上实现了接近完美的准确率,但即使使用超过10000k的训练样本,其OOD性能也会急剧下降;(2)CoT数据的粒度与泛化性能密切相关;更细粒度的CoT数据能带来更好的泛化效果;(3)CoT表现出卓越的样本效率,用更少(甚至80%)的数据就能匹配QA的性能。理论上,我们证明了复合任务本质上允许Q-A数据中存在与真实推理原则不一致的捷径,而CoT强制内化有效的依赖结构,因此可以实现更好的泛化。此外,我们表明Transformer的位置嵌入可以通过强调长CoT序列中的子任务条件复现来放大泛化能力。我们的理论和经验分析共同为CoT推理作为一种关键的训练范式提供了有力的证据,该范式能够使LM在现实世界中复合任务的分布偏移下实现泛化。

🔬 方法详解

问题定义:论文旨在解决语言模型在面对分布偏移时,对复杂复合任务泛化能力不足的问题。现有QA模型虽然在同分布数据上表现良好,但容易学习到数据中的捷径,导致在新的、分布不同的任务上性能急剧下降。这种现象表明模型缺乏真正的推理能力,无法适应真实世界的复杂场景。

核心思路:论文的核心思路是利用思维链(CoT)推理来提升模型的泛化能力。CoT通过提供更详细的推理步骤,迫使模型学习任务的内在依赖关系,而不是仅仅依赖输入和输出之间的表面相关性。更细粒度的CoT数据能够提供更丰富的推理信息,从而帮助模型更好地理解任务的结构和逻辑。

技术框架:论文采用标准的Transformer架构作为基础语言模型。训练过程主要分为两个阶段:首先,使用QA数据训练模型,使其具备基本的问答能力;然后,使用CoT数据对模型进行微调,使其学习推理过程。关键在于CoT数据的构建,需要保证其粒度足够细,能够清晰地展示任务的分解和推理步骤。

关键创新:论文的关键创新在于发现了CoT数据粒度与泛化性能之间的强相关性。更细粒度的CoT数据能够提供更丰富的推理信息,从而帮助模型更好地理解任务的结构和逻辑,提升模型的泛化能力。此外,论文还从理论上解释了CoT能够提升泛化能力的原因,并探讨了Transformer的位置嵌入在CoT推理中的作用。

关键设计:论文的关键设计包括:1) CoT数据的粒度控制,通过人工标注或程序生成的方式,确保CoT数据能够清晰地展示任务的分解和推理步骤;2) 损失函数的设计,可以使用标准的交叉熵损失函数,也可以根据任务的特点进行调整,例如,对推理步骤的正确性进行加权;3) Transformer的位置嵌入,通过调整位置嵌入的参数,可以强调CoT序列中子任务条件复现,从而提升模型的泛化能力。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,使用细粒度CoT数据训练的模型在分布偏移下的性能显著优于使用QA数据训练的模型。例如,在某些复合任务上,QA训练的模型OOD准确率接近于0,而使用CoT训练的模型可以达到50%以上。此外,CoT还表现出卓越的样本效率,用更少(甚至80%)的数据就能匹配QA的性能。

🎯 应用场景

该研究成果可应用于各种需要语言模型具备强大泛化能力的场景,例如智能客服、自动驾驶、医疗诊断等。通过使用CoT训练,可以使模型更好地理解用户的意图,处理复杂的任务,并在新的、未知的环境中保持良好的性能。这有助于提高人工智能系统的可靠性和实用性。

📄 摘要(原文)

Generalization to novel compound tasks under distribution shift is important for deploying transformer-based language models (LMs). This work investigates Chain-of-Thought (CoT) reasoning as a means to enhance OOD generalization. Through controlled experiments across several compound tasks, we reveal three key insights: (1) While QA-trained models achieve near-perfect in-distribution accuracy, their OOD performance degrades catastrophically, even with 10000k+ training examples; (2) the granularity of CoT data strongly correlates with generalization performance; finer-grained CoT data leads to better generalization; (3) CoT exhibits remarkable sample efficiency, matching QA performance with much less (even 80%) data. Theoretically, we demonstrate that compound tasks inherently permit shortcuts in Q-A data that misalign with true reasoning principles, while CoT forces internalization of valid dependency structures, and thus can achieve better generalization. Further, we show that transformer positional embeddings can amplify generalization by emphasizing subtask condition recurrence in long CoT sequences. Our combined theoretical and empirical analysis provides compelling evidence for CoT reasoning as a crucial training paradigm for enabling LM generalization under real-world distributional shifts for compound tasks.