What Do Learning Dynamics Reveal About Generalization in LLM Reasoning?

📄 arXiv: 2411.07681v2 📥 PDF

作者: Katie Kang, Amrith Setlur, Dibya Ghosh, Jacob Steinhardt, Claire Tomlin, Sergey Levine, Aviral Kumar

分类: cs.LG

发布日期: 2024-11-12 (更新: 2024-11-18)


💡 一句话要点

通过学习动态揭示LLM推理泛化能力:提出预记忆训练准确率指标

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大型语言模型 推理能力 泛化能力 预记忆训练准确率 数据管理

📋 核心要点

  1. 现有LLM推理能力背后的机制尚不明确,难以区分模型是真正学会推理,还是仅仅记忆了训练数据。
  2. 论文提出“预记忆训练准确率”指标,用于衡量模型在开始记忆训练数据推理步骤前的训练准确率,以此表征模型的泛化能力。
  3. 实验表明,该指标能有效预测测试准确率,并指导数据管理,提升数据效率1.5-2倍,优于传统方法。

📝 摘要(中文)

尽管大型语言模型(LLM)展现出卓越的能力,但其问题解决机制仍然难以捉摸。本文旨在深入理解LLM微调的学习动态如何影响下游泛化能力。分析聚焦于推理任务,其问题结构允许区分记忆(完全复制训练数据中的推理步骤)和性能(最终解决方案的正确性)。研究发现,模型的泛化行为可以通过一个名为“预记忆训练准确率”的训练指标有效表征:即模型样本开始复制训练集中精确推理步骤之前,在训练查询上的准确率。在数据集层面,该指标能够可靠地预测测试准确率,在各种模型(Llama3 8B, Gemma2 9B)、数据集(GSM8k, MATH)和训练配置上实现约0.9或更高的R^2。在每个样本层面,该指标也指示了单个模型预测对训练查询扰动的鲁棒性。通过将模型的学习行为与其泛化能力联系起来,预记忆训练准确率可以指导对训练策略的针对性改进。以数据管理为例,结果表明,优先考虑预记忆准确率低的样本,与独立同分布数据缩放相比,数据效率提高了1.5-2倍,并且优于其他标准数据管理技术。

🔬 方法详解

问题定义:论文旨在解决大型语言模型(LLM)在推理任务中泛化能力评估的问题。现有方法难以区分模型是通过真正的推理能力解决问题,还是仅仅通过记忆训练数据中的推理步骤来获得正确答案。这种区分对于理解和提升LLM的泛化能力至关重要。

核心思路:论文的核心思路是提出一个名为“预记忆训练准确率”的指标,该指标衡量模型在开始记忆训练数据中的推理步骤之前,在训练数据上的准确率。作者认为,在模型开始记忆之前达到的准确率更能反映模型的真实推理能力,因此可以作为泛化能力的有效指标。

技术框架:论文的技术框架主要包括以下几个步骤:1)定义推理任务,并构建包含训练集和测试集的数据集;2)使用LLM在训练集上进行微调;3)在训练过程中,计算“预记忆训练准确率”指标;4)分析“预记忆训练准确率”与测试集准确率之间的关系,验证其作为泛化能力指标的有效性;5)利用“预记忆训练准确率”指导数据管理,例如优先选择预记忆准确率低的样本进行训练。

关键创新:论文最重要的技术创新点在于提出了“预记忆训练准确率”这一指标,该指标能够有效区分LLM的记忆行为和真正的推理能力,从而更好地评估和提升模型的泛化能力。与现有方法相比,该指标更加直接地反映了模型的推理能力,并且可以用于指导训练策略的改进。

关键设计:论文的关键设计包括:1)精确定义“预记忆”的概念,即模型开始复制训练数据中的推理步骤的时间点;2)设计算法来自动检测模型是否开始记忆训练数据;3)选择合适的推理任务和数据集,例如GSM8k和MATH,这些数据集具有明确的推理步骤和答案;4)使用常见的LLM架构,例如Llama3和Gemma2,并在这些模型上验证所提出的指标的有效性。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,预记忆训练准确率能够可靠地预测测试准确率,在Llama3 8B和Gemma2 9B等模型上,以及GSM8k和MATH等数据集上,实现了约0.9或更高的R^2。此外,通过优先考虑预记忆准确率低的样本进行训练,数据效率提高了1.5-2倍,并且优于其他标准数据管理技术。这些结果表明,预记忆训练准确率是一个有效的泛化能力指标,可以用于指导LLM的训练。

🎯 应用场景

该研究成果可应用于提升LLM在各种推理任务中的性能,例如数学问题求解、代码生成、逻辑推理等。通过使用预记忆训练准确率指导数据管理和训练策略,可以提高LLM的泛化能力和数据效率,降低训练成本。此外,该研究也有助于更好地理解LLM的工作机制,为开发更强大的AI系统提供理论基础。

📄 摘要(原文)

Despite the remarkable capabilities of modern large language models (LLMs), the mechanisms behind their problem-solving abilities remain elusive. In this work, we aim to better understand how the learning dynamics of LLM finetuning shapes downstream generalization. Our analysis focuses on reasoning tasks, whose problem structure allows us to distinguish between memorization (the exact replication of reasoning steps from the training data) and performance (the correctness of the final solution). We find that a model's generalization behavior can be effectively characterized by a training metric we call pre-memorization train accuracy: the accuracy of model samples on training queries before they begin to copy the exact reasoning steps from the training set. On the dataset level, this metric is able to reliably predict test accuracy, achieving $R^2$ of around or exceeding 0.9 across various models (Llama3 8, Gemma2 9B), datasets (GSM8k, MATH), and training configurations. On a per-example level, this metric is also indicative of whether individual model predictions are robust to perturbations in the training query. By connecting a model's learning behavior to its generalization, pre-memorization train accuracy can guide targeted improvements to training strategies. We focus on data curation as an example, and show that prioritizing examples with low pre-memorization accuracy leads to 1.5-2x improvements in data efficiency compared to i.i.d. data scaling, and outperforms other standard data curation techniques.