Adaptive Constraint Propagation: Scaling Structured Inference for Large Language Models via Meta-Reinforcement Learning
作者: Ibne Farabi Shihab, Sanjeda Akter, Anuj Sharma
分类: cs.CL
发布日期: 2025-12-31 (更新: 2026-01-25)
💡 一句话要点
提出MetaJuLS,通过元强化学习加速大语言模型中的结构化推理。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 元强化学习 结构化推理 约束传播 图注意力网络 大语言模型 绿色AI 跨领域适应
📋 核心要点
- 现有大语言模型在结构化推理中面临复杂约束,需要耗时的任务特定训练。
- MetaJuLS通过元强化学习学习通用约束传播策略,实现跨语言和任务的快速适应。
- 实验表明,MetaJuLS在速度上提升1.5-2.0倍,精度与SOTA解析器持平,并能快速适应新任务。
📝 摘要(中文)
大型语言模型越来越多地需要结构化推理,例如JSON模式强制执行和多语言解析,这些任务要求输出满足复杂的约束。我们提出MetaJuLS,一种元强化学习方法,它学习通用的约束传播策略,该策略适用于多种语言和任务,而无需针对特定任务进行重新训练。通过将结构化推理公式化为自适应约束传播,并使用元学习训练图注意力网络,MetaJuLS实现了比GPU优化的基线方法快1.5-2.0倍的速度,同时保持了最先进解析器0.2%的精度。在跨10种语言的通用依存关系和LLM约束生成(LogicBench,GSM8K-Constrained)上,MetaJuLS展示了快速的跨领域适应性:在英语解析上训练的策略可以通过5-10个梯度步骤(5-15秒)适应新的语言和任务,而不需要数小时的特定任务训练。机制分析表明,该策略发现了类似人类的解析策略(easy-first)和新颖的非直观启发式方法。通过减少LLM部署中的传播步骤,MetaJuLS通过直接减少推理碳足迹为绿色AI做出贡献。
🔬 方法详解
问题定义:现有的大语言模型在处理需要满足复杂约束的结构化推理任务时,例如JSON模式强制执行和多语言解析,面临着计算成本高昂和需要大量特定任务训练的问题。现有的方法通常需要针对每个任务进行重新训练,这使得它们难以扩展到新的语言和任务。
核心思路:MetaJuLS的核心思路是将结构化推理问题建模为自适应约束传播问题,并利用元强化学习来学习一个通用的约束传播策略。该策略能够跨语言和任务进行泛化,从而避免了针对每个任务进行重新训练的需要。通过学习如何有效地传播约束,MetaJuLS能够减少推理所需的步骤,从而提高速度并降低计算成本。
技术框架:MetaJuLS的技术框架主要包括以下几个模块:1)图表示模块:将输入数据(例如,句子或JSON模式)表示为一个图,其中节点表示单词或JSON元素,边表示它们之间的关系。2)图注意力网络(GAT):使用GAT来学习节点之间的依赖关系,并传播约束信息。3)元强化学习模块:使用元强化学习来训练GAT,使其能够学习通用的约束传播策略。4)自适应约束传播模块:使用学习到的策略来动态地选择要传播的约束,从而减少推理所需的步骤。
关键创新:MetaJuLS的关键创新在于它使用元强化学习来学习通用的约束传播策略。与传统的需要针对每个任务进行重新训练的方法不同,MetaJuLS能够通过少量的梯度步骤快速适应新的语言和任务。此外,MetaJuLS还发现了一些类似人类的解析策略和新颖的非直观启发式方法,这表明该方法具有很强的学习能力。
关键设计:MetaJuLS的关键设计包括:1)使用图注意力网络来学习节点之间的依赖关系。2)使用元强化学习来训练GAT,使其能够学习通用的约束传播策略。3)设计一个奖励函数,鼓励模型减少推理所需的步骤,同时保持精度。4)使用少量的梯度步骤来适应新的语言和任务。
🖼️ 关键图片
📊 实验亮点
MetaJuLS在通用依存关系解析和LLM约束生成任务上取得了显著的性能提升。在速度方面,MetaJuLS比GPU优化的基线方法快1.5-2.0倍,同时保持了最先进解析器0.2%的精度。在跨领域适应性方面,MetaJuLS仅需5-10个梯度步骤(5-15秒)即可适应新的语言和任务,而不需要数小时的特定任务训练。
🎯 应用场景
MetaJuLS可应用于各种需要结构化推理的大语言模型应用场景,例如:多语言机器翻译、代码生成、JSON模式验证、知识图谱推理等。通过加速推理过程,MetaJuLS能够降低计算成本,提高用户体验,并减少大语言模型的碳排放,促进绿色AI的发展。
📄 摘要(原文)
Large language models increasingly require structured inference, from JSON schema enforcement to multi-lingual parsing, where outputs must satisfy complex constraints. We introduce MetaJuLS, a meta-reinforcement learning approach that learns universal constraint propagation policies applicable across languages and tasks without task-specific retraining. By formulating structured inference as adaptive constraint propagation and training a Graph Attention Network with meta-learning, MetaJuLS achieves 1.5--2.0$\times$ speedups over GPU-optimized baselines while maintaining within 0.2\% accuracy of state-of-the-art parsers. On Universal Dependencies across 10 languages and LLM-constrained generation (LogicBench, GSM8K-Constrained), MetaJuLS demonstrates rapid cross-domain adaptation: a policy trained on English parsing adapts to new languages and tasks with 5--10 gradient steps (5--15 seconds) rather than requiring hours of task-specific training. Mechanistic analysis reveals the policy discovers human-like parsing strategies (easy-first) and novel non-intuitive heuristics. By reducing propagation steps in LLM deployments, MetaJuLS contributes to Green AI by directly reducing inference carbon footprint.