Building Math Agents with Multi-Turn Iterative Preference Learning

📄 arXiv: 2409.02392v2 📥 PDF

作者: Wei Xiong, Chengshuai Shi, Jiaming Shen, Aviv Rosenberg, Zhen Qin, Daniele Calandriello, Misha Khalman, Rishabh Joshi, Bilal Piot, Mohammad Saleh, Chi Jin, Tong Zhang, Tianqi Liu

分类: cs.LG, stat.ML

发布日期: 2024-09-04 (更新: 2025-02-27)

备注: A multi-turn direct preference learning framework for tool-integrated reasoning tasks


💡 一句话要点

提出多轮迭代偏好学习框架,提升数学Agent工具集成推理能力

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 多轮推理 偏好学习 数学Agent 工具集成 代码解释器

📋 核心要点

  1. 现有方法难以有效利用外部工具和多轮推理提升LLM的数学问题求解能力。
  2. 提出多轮直接偏好学习框架,优化轨迹级偏好,适配工具集成和多轮推理。
  3. 实验表明,该框架能显著提升LLM在GSM8K和MATH数据集上的数学问题求解性能。

📝 摘要(中文)

本文提出了一种多轮直接偏好学习框架,旨在提升大型语言模型(LLMs)在数学问题求解中的能力,特别是当模型集成外部工具(如代码解释器)并采用多轮思维链(CoT)推理时。现有直接偏好学习算法主要针对单轮对话任务设计,无法充分应对数学推理任务中多轮推理和外部工具集成的复杂性。该框架通过利用来自代码解释器的反馈,优化轨迹级别的偏好。具体实现包括多轮DPO和多轮KTO。实验结果表明,使用来自GSM8K和MATH数据集的增强提示集训练各种语言模型后,性能得到显著提升。例如,经过监督微调的Gemma-1.1-it-7B模型在GSM8K上的性能从77.5%提高到83.9%,在MATH上的性能从46.1%提高到51.2%。Gemma-2-it-9B模型在GSM8K上从84.1%提高到86.3%,在MATH上从51.0%提高到54.5%。

🔬 方法详解

问题定义:现有方法在提升LLM数学问题求解能力时,主要依赖合成数据生成和监督微调(SFT)。直接偏好学习(DPO)是一种有潜力的替代方案,但现有DPO算法主要针对单轮对话任务设计,无法有效处理数学推理中涉及的多轮交互和外部工具集成,导致性能提升受限。因此,需要一种专门为多轮推理和工具集成设计的偏好学习方法。

核心思路:论文的核心思路是设计一种多轮直接偏好学习框架,该框架能够利用代码解释器等外部工具的反馈,并优化轨迹级别的偏好。通过在多轮交互过程中学习更优的策略,模型可以更好地利用外部工具,并进行更有效的推理。这种方法避免了显式奖励函数的设计,直接从偏好数据中学习策略。

技术框架:该框架包含以下主要步骤:1) 使用增强的提示集生成多轮推理轨迹,包括模型输出和工具反馈。2) 构建偏好数据集,其中包含模型生成的不同轨迹,并根据其质量进行排序。3) 使用多轮DPO或多轮KTO等算法,基于偏好数据集训练语言模型。4) 评估模型在数学问题求解任务上的性能。框架的关键在于如何有效地利用多轮交互过程中的信息,并将其转化为偏好信号。

关键创新:该论文的关键创新在于提出了一个专门为工具集成和多轮推理设计的直接偏好学习框架。与传统的单轮DPO方法不同,该框架能够处理多轮交互过程中的复杂依赖关系,并利用外部工具的反馈来指导模型的学习。此外,该框架还提出了多轮DPO和多轮KTO两种具体的实现方式,为实际应用提供了灵活性。

关键设计:该框架的关键设计包括:1) 如何构建高质量的偏好数据集,例如,可以通过人工评估或自动评估的方式对轨迹进行排序。2) 如何有效地利用代码解释器等外部工具的反馈,例如,可以将工具的输出作为奖励信号或约束条件。3) 如何设计合适的损失函数,以优化轨迹级别的偏好。具体而言,多轮DPO使用pairwise ranking loss,而多轮KTO则使用hinge loss来区分更优和更差的轨迹。

📊 实验亮点

实验结果表明,该框架能够显著提升LLM在GSM8K和MATH数据集上的数学问题求解性能。例如,经过监督微调的Gemma-1.1-it-7B模型在GSM8K上的性能从77.5%提高到83.9%,在MATH上的性能从46.1%提高到51.2%。Gemma-2-it-9B模型在GSM8K上从84.1%提高到86.3%,在MATH上从51.0%提高到54.5%。这些结果表明,多轮直接偏好学习框架能够有效地利用外部工具和多轮推理,提升LLM的数学能力。

🎯 应用场景

该研究成果可应用于开发更强大的数学Agent,提升其在教育、科研、金融等领域的应用价值。例如,可以构建智能辅导系统,帮助学生解决数学难题;也可以用于自动化金融建模,提高投资决策的效率和准确性。此外,该方法还可以推广到其他需要多轮推理和工具集成的任务,如代码生成、科学发现等。

📄 摘要(原文)

Recent studies have shown that large language models' (LLMs) mathematical problem-solving capabilities can be enhanced by integrating external tools, such as code interpreters, and employing multi-turn Chain-of-Thought (CoT) reasoning. While current methods focus on synthetic data generation and Supervised Fine-Tuning (SFT), this paper studies the complementary direct preference learning approach to further improve model performance. However, existing direct preference learning algorithms are originally designed for the single-turn chat task, and do not fully address the complexities of multi-turn reasoning and external tool integration required for tool-integrated mathematical reasoning tasks. To fill in this gap, we introduce a multi-turn direct preference learning framework, tailored for this context, that leverages feedback from code interpreters and optimizes trajectory-level preferences. This framework includes multi-turn DPO and multi-turn KTO as specific implementations. The effectiveness of our framework is validated through training of various language models using an augmented prompt set from the GSM8K and MATH datasets. Our results demonstrate substantial improvements: a supervised fine-tuned Gemma-1.1-it-7B model's performance increased from 77.5% to 83.9% on GSM8K and from 46.1% to 51.2% on MATH. Similarly, a Gemma-2-it-9B model improved from 84.1% to 86.3% on GSM8K and from 51.0% to 54.5% on MATH.