Can Memory-Augmented Language Models Generalize on Reasoning-in-a-Haystack Tasks?
作者: Payel Das, Ching-Yun Ko, Sihui Dai, Georgios Kollias, Subhajit Chaudhury, Aurelie Lozano
分类: cs.CL, cs.LG
发布日期: 2025-03-10
💡 一句话要点
提出MemReasoner,增强LLM在复杂推理任务中的泛化能力
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 记忆增强 语言模型 推理任务 长上下文 弱监督 泛化能力 多跳推理 注意力机制
📋 核心要点
- 大型语言模型在长上下文推理中表现出局限性,难以有效利用上下文信息进行复杂推理。
- MemReasoner通过学习上下文中事实的相对顺序,并允许跳过不相关信息,从而增强模型推理能力。
- 实验表明,MemReasoner在合成多跳推理任务中表现出强大的泛化能力,尤其是在弱监督条件下。
📝 摘要(中文)
大型语言模型在推理任务中常常表现出脆弱性,尤其是在对上下文执行长链推理时。本文提出了一种新的、简单的记忆增强型LLM架构MemReasoner,其中记忆学习上下文中事实的相对顺序,并能够跳过它们,而解码器选择性地关注记忆。MemReasoner采用端到端的方式进行训练,并可选择使用不同程度的支持事实监督。我们将MemReasoner与现有的记忆增强型Transformer模型和状态空间模型在两个不同的合成多跳推理任务上进行了训练。在各种具有挑战性的场景下进行的实验,包括存在长干扰文本或测试集中目标答案的变化,表明MemReasoner在单跳和双跳任务上都具有很强的泛化能力。MemReasoner的泛化能力是通过使用无或弱支持事实监督来实现的(分别使用0%和1%的支持事实进行单跳和双跳任务)。相比之下,基线模型总体上难以泛化,并且从完全支持事实监督中获益甚微。结果突出了显式记忆机制与额外的弱监督相结合对于提高大型语言模型在推理任务中的上下文处理能力的重要性。
🔬 方法详解
问题定义:大型语言模型在处理需要长程依赖和复杂推理的任务时,容易受到上下文长度的限制,难以有效利用所有相关信息。现有的方法,如直接使用Transformer或增加上下文长度,往往无法很好地解决这个问题,尤其是在存在大量干扰信息的情况下。
核心思路:MemReasoner的核心思路是引入一个外部记忆模块,该模块能够学习上下文中事实的相对顺序,并允许模型在推理过程中跳过不相关的信息。通过这种方式,模型可以更有效地关注关键信息,从而提高推理的准确性和效率。
技术框架:MemReasoner的整体架构包括一个语言模型(如Transformer)、一个记忆模块和一个解码器。首先,语言模型对输入文本进行编码,并将编码后的信息存储到记忆模块中。记忆模块学习事实的相对顺序。然后,解码器选择性地从记忆模块中检索相关信息,并生成最终的答案。整个过程采用端到端的方式进行训练。
关键创新:MemReasoner的关键创新在于其记忆模块的设计,该模块不仅存储了上下文信息,还学习了这些信息的相对顺序。这种设计使得模型能够更好地理解上下文的结构,并更有效地检索相关信息。此外,模型还采用了弱监督的方式进行训练,进一步提高了其泛化能力。
关键设计:MemReasoner使用Transformer作为基础语言模型。记忆模块的具体实现方式未知,但其核心功能是学习事实的相对顺序。解码器使用注意力机制从记忆模块中检索信息。损失函数包括语言模型损失和辅助损失,用于指导记忆模块的学习。弱监督信号通过少量支持事实提供,用于指导模型关注关键信息。
🖼️ 关键图片
📊 实验亮点
实验结果表明,MemReasoner在单跳和双跳推理任务上都表现出强大的泛化能力,尤其是在弱监督条件下。与基线模型相比,MemReasoner能够更好地处理长上下文和干扰信息,并且能够更有效地利用支持事实进行推理。在某些情况下,MemReasoner仅使用1%的支持事实就能达到与基线模型使用全部支持事实相当的性能。
🎯 应用场景
MemReasoner具有广泛的应用前景,例如在问答系统、信息检索、对话系统等领域。它可以帮助模型更好地理解长文本,并从中提取关键信息进行推理。此外,该模型还可以应用于需要处理大量干扰信息的场景,例如从复杂的文档中提取关键信息。
📄 摘要(原文)
Large language models often expose their brittleness in reasoning tasks, especially while executing long chains of reasoning over context. We propose MemReasoner, a new and simple memory-augmented LLM architecture, in which the memory learns the relative order of facts in context, and enables hopping over them, while the decoder selectively attends to the memory. MemReasoner is trained end-to-end, with optional supporting fact supervision of varying degrees. We train MemReasoner, along with existing memory-augmented transformer models and a state-space model, on two distinct synthetic multi-hop reasoning tasks. Experiments performed under a variety of challenging scenarios, including the presence of long distractor text or target answer changes in test set, show strong generalization of MemReasoner on both single- and two-hop tasks. This generalization of MemReasoner is achieved using none-to-weak supporting fact supervision (using none and 1\% of supporting facts for one- and two-hop tasks, respectively). In contrast, baseline models overall struggle to generalize and benefit far less from using full supporting fact supervision. The results highlight the importance of explicit memory mechanisms, combined with additional weak supervision, for improving large language model's context processing ability toward reasoning tasks.