Data Mixing Agent: Learning to Re-weight Domains for Continual Pre-training
作者: Kailai Yang, Xiao Liu, Lei Ji, Hao Li, Yeyun Gong, Peng Cheng, Mao Yang
分类: cs.LG, cs.AI, cs.CL
发布日期: 2025-07-21
💡 一句话要点
提出数据混合Agent,通过强化学习自动学习领域重加权策略,提升持续预训练效果。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 持续预训练 领域重加权 强化学习 数据混合 灾难性遗忘
📋 核心要点
- 持续预训练面临灾难性遗忘问题,现有领域重加权方法依赖人工启发式策略,缺乏通用性和自适应性。
- 提出数据混合Agent,利用强化学习自动学习领域重加权策略,无需人工干预,实现端到端优化。
- 实验表明,该Agent在数学推理和代码生成等任务上,显著提升了持续预训练的性能,并具有良好的泛化能力。
📝 摘要(中文)
在小规模特定任务数据上进行持续预训练是提升大型语言模型在新目标领域性能的有效方法,但同时也存在灾难性遗忘原始能力的风险。一种常见的解决方案是对来自源领域和目标领域的训练数据混合进行领域重加权,以实现平衡的性能。以往的领域重加权策略依赖于人工指定,基于人类直觉或经验结果的启发式方法。本文证明了更通用的启发式方法可以通过参数化来实现,并提出了数据混合Agent,这是第一个基于模型的端到端框架,可以学习对领域进行重加权。该Agent通过强化学习,在大规模数据混合轨迹上学习可泛化的启发式方法,并从评估环境中获得相应的反馈。在数学推理的持续预训练实验表明,数据混合Agent在跨源领域和目标领域基准测试中实现了平衡的性能,优于强大的基线。此外,它可以在未见过的源领域、目标模型和领域空间中很好地泛化,而无需重新训练。直接应用于代码生成领域也表明了其在目标领域中的适应性。进一步的分析表明,该Agent的启发式方法与人类直觉高度一致,并且能够以更少的源领域数据实现卓越的模型性能。
🔬 方法详解
问题定义:论文旨在解决持续预训练过程中,模型在适应新领域知识时,容易遗忘原有领域知识的问题,即灾难性遗忘。现有的领域重加权方法依赖人工设计的启发式规则,这些规则通常基于经验或直觉,缺乏通用性和自适应性,难以在不同领域和模型之间迁移。
核心思路:论文的核心思路是将领域重加权过程建模为一个强化学习问题。通过训练一个Agent,使其能够根据当前模型的性能反馈,动态地调整不同领域数据的权重,从而在保留原有知识的同时,有效地学习新知识。这种方法避免了人工设计启发式规则的繁琐过程,并能够自动学习最优的重加权策略。
技术框架:整体框架包含三个主要组成部分:数据混合Agent、预训练模型和评估环境。数据混合Agent负责生成数据混合比例,预训练模型根据混合后的数据进行训练,评估环境则根据模型在源领域和目标领域的性能,提供反馈信号给Agent。Agent通过强化学习算法(如Policy Gradient),不断优化其重加权策略,以最大化模型在所有领域的综合性能。
关键创新:最重要的创新点在于将领域重加权问题转化为一个可学习的强化学习问题,并提出了数据混合Agent这一概念。与传统的基于人工启发式规则的方法相比,该方法能够自动学习最优的重加权策略,具有更强的通用性和自适应性。此外,该Agent可以在不同的源领域、目标模型和领域空间中进行泛化,无需重新训练。
关键设计:Agent的网络结构可以采用Transformer或其他序列模型,输入是当前模型的性能指标(如在源领域和目标领域的准确率),输出是不同领域数据的权重。损失函数采用强化学习中的奖励函数,奖励值可以根据模型在各个领域的性能进行设计,例如,可以采用加权平均准确率作为奖励值。在训练过程中,需要设计合适的探索策略,以鼓励Agent尝试不同的重加权策略,并避免陷入局部最优。
🖼️ 关键图片
📊 实验亮点
实验结果表明,数据混合Agent在数学推理任务上,能够显著提升持续预训练的性能,优于传统的基于人工启发式规则的方法。具体而言,该Agent在平衡源领域和目标领域的性能方面表现出色,并且具有良好的泛化能力,可以在未见过的源领域、目标模型和领域空间中进行泛化。此外,该Agent还能够以更少的源领域数据实现卓越的模型性能,表明其能够更有效地利用数据。
🎯 应用场景
该研究成果可广泛应用于各种需要持续学习的场景,例如:大型语言模型的持续预训练、机器人技能学习、自动驾驶等。通过自动学习领域重加权策略,可以有效提升模型在不断变化的环境中的适应能力,并降低人工干预的成本。此外,该方法还可以用于个性化推荐系统,根据用户的历史行为和偏好,动态调整不同类型内容的权重,从而提供更精准的推荐服务。
📄 摘要(原文)
Continual pre-training on small-scale task-specific data is an effective method for improving large language models in new target fields, yet it risks catastrophic forgetting of their original capabilities. A common solution is to re-weight training data mixtures from source and target fields on a domain space to achieve balanced performance. Previous domain reweighting strategies rely on manual designation with certain heuristics based on human intuition or empirical results. In this work, we prove that more general heuristics can be parameterized by proposing Data Mixing Agent, the first model-based, end-to-end framework that learns to re-weight domains. The agent learns generalizable heuristics through reinforcement learning on large quantities of data mixing trajectories with corresponding feedback from an evaluation environment. Experiments in continual pre-training on math reasoning show that Data Mixing Agent outperforms strong baselines in achieving balanced performance across source and target field benchmarks. Furthermore, it generalizes well across unseen source fields, target models, and domain spaces without retraining. Direct application to the code generation field also indicates its adaptability across target domains. Further analysis showcases the agents' well-aligned heuristics with human intuitions and their efficiency in achieving superior model performance with less source-field data.