$\texttt{LM}^\texttt{2}$: A Simple Society of Language Models Solves Complex Reasoning

📄 arXiv: 2404.02255v1 📥 PDF

作者: Gurusha Juneja, Subhabrata Dutta, Tanmoy Chakraborty

分类: cs.CL, cs.AI

发布日期: 2024-04-02

🔗 代码/项目: GITHUB


💡 一句话要点

提出LM²以解决复杂推理问题

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大型语言模型 复杂推理 模块化设计 推理能力 策略学习

📋 核心要点

  1. 现有方法在处理复杂的多步骤推理时缺乏有效的协调机制,导致推理能力不足。
  2. LM²通过将分解、求解和验证模块化为三个不同的语言模型,提升了推理的准确性和效率。
  3. 实验结果显示,LM²在MATH、JEEBench和MedQA问题上分别比最佳基线提高了8.1%、7.71%和9.7%。

📝 摘要(中文)

尽管大型语言模型(LLMs)展现出新兴的推理能力,但在处理复杂的多步骤推理时常常失去方向。现有研究表明,通过将原始问题分解为多个子问题的方式可以增强LLM的推理能力。然而,这些技术未能有效协调分解器和求解器模块之间的关系。本文提出LM²,模块化地将分解、解决和验证过程分为三个不同的语言模型。分解器识别解决问题所需的关键概念,并根据推理要求生成逐步的子问题。求解器生成子问题的解决方案,随后由验证器模块进行检查。根据验证器的反馈,构建推理上下文。大量实验表明,LM²在领域内外的推理问题上优于现有方法。

🔬 方法详解

问题定义:本文旨在解决大型语言模型在复杂多步骤推理中的协调不足问题。现有方法未能有效跟踪分解器与求解器之间的关系,导致推理能力下降。

核心思路:LM²的核心思想是将推理过程模块化,分别使用分解器、求解器和验证器来处理问题。这种设计允许每个模块专注于其特定任务,从而提高整体推理的准确性和效率。

技术框架:LM²的整体架构包括三个主要模块:分解器负责识别关键概念并生成子问题;求解器生成子问题的解决方案;验证器检查解决方案的正确性,并根据反馈调整推理上下文。

关键创新:LM²的主要创新在于模块化设计,使得分解器和求解器之间的协调更加高效。这与现有方法的单一模型设计形成鲜明对比,显著提升了推理能力。

关键设计:在训练过程中,LM²采用策略学习来协调各个模块的工作。具体的参数设置和损失函数设计未在摘要中详细说明,需参考论文的具体内容。

📊 实验亮点

LM²在多个推理任务上的实验结果显示,其性能显著优于现有方法。在MATH、JEEBench和MedQA问题上,LM²分别提高了8.1%、7.71%和9.7%,展现出其在复杂推理场景中的有效性和优势。

🎯 应用场景

LM²的研究成果在教育、医疗和科学研究等多个领域具有广泛的应用潜力。通过提高复杂推理的准确性,LM²可以帮助学生更好地理解数学和科学问题,也可以在医疗诊断和决策支持系统中提供更可靠的推理能力,推动相关领域的发展。

📄 摘要(原文)

Despite demonstrating emergent reasoning abilities, Large Language Models (LLMS) often lose track of complex, multi-step reasoning. Existing studies show that providing guidance via decomposing the original question into multiple subproblems elicits more robustness in LLM reasoning -- a decomposer generates the subproblems, and a solver solves each of these subproblems. However, these techniques fail to accommodate coordination between the decomposer and the solver modules (either in a single model or different specialized ones) -- the decomposer does not keep track of the ability of the solver to follow the decomposed reasoning. In this paper, we propose LM2 to address these challenges. LM2 modularizes the decomposition, solution, and verification into three different language models. The decomposer module identifies the key concepts necessary to solve the problem and generates step-by-step subquestions according to the reasoning requirement. The solver model generates the solution to the subproblems that are then checked by the verifier module; depending upon the feedback from the verifier, the reasoning context is constructed using the subproblems and the solutions. These models are trained to coordinate using policy learning. Exhaustive experimentation suggests the superiority of LM2 over existing methods on in- and out-domain reasoning problems, outperforming the best baselines by $8.1\%$ on MATH, $7.71\%$ on JEEBench, and $9.7\%$ on MedQA problems (code available at https://github.com/LCS2-IIITD/Language_Model_Multiplex).