Joint Selection for Large-Scale Pre-Training Data via Policy Gradient-based Mask Learning
作者: Ziqing Fan, Yuqiao Xian, Yan Sun, Li Shen
分类: cs.CL, cs.LG
发布日期: 2025-12-30
💡 一句话要点
提出DATAMASK,通过策略梯度优化大规模预训练数据联合选择,提升训练效率和模型性能。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大规模预训练 数据选择 策略梯度 掩码学习 质量多样性 语言模型 模型优化
📋 核心要点
- 现有大规模预训练数据选择方法难以兼顾质量和多样性,导致训练效率和模型性能受限。
- DATAMASK将数据选择视为掩码学习问题,通过策略梯度优化,联合优化质量和多样性指标。
- 实验表明,DATAMASK在万亿级token数据集上显著提升了1.5B和7B模型的性能,并大幅缩减了选择时间。
📝 摘要(中文)
大规模语言模型的预训练需要精细的数据配方,以显著提高训练效率和模型性能。其中一个重要组成部分是基于规则、LLM判断或嵌入统计信息产生的分数来选择样本,这些分数大致可分为质量和多样性指标。由于应用于FineWeb和DCLM等万亿级token预训练数据集时计算成本高昂,因此很少在单个选择过程中同时考虑这两种或多种类型的指标。然而,我们的实证研究表明,基于质量指标选择样本在长期预训练过程中表现出严重的收益递减,而基于多样性指标进行选择会删除太多有价值的高质量样本,这两者都限制了预训练LLM的能力。因此,我们引入了DATAMASK,这是一种新颖且高效的联合学习框架,专为大规模预训练数据选择而设计,可以在统一的过程中同时优化多种类型的指标,本研究特别关注质量和多样性指标。DATAMASK将选择过程视为掩码学习问题,涉及迭代采样数据掩码,基于具有采样掩码的预定义目标计算策略梯度,以及更新掩码采样logits。通过基于策略梯度的优化和各种加速增强,与贪婪算法相比,它显着减少了98.9%的选择时间,使我们的研究能够探索万亿级token内的联合学习。借助DATAMASK,我们从15万亿token的FineWeb数据集中选择了一个约10%的子集,称为FineWeb-Mask。在12个不同的任务中进行评估,我们在1.5B密集模型上实现了3.2%的显着改进,在7B MoE模型上实现了1.9%的显着改进。
🔬 方法详解
问题定义:现有的大规模预训练数据选择方法,如基于质量或多样性的选择,存在各自的局限性。单独基于质量的选择会面临收益递减,而单独基于多样性的选择会移除有价值的高质量样本。因此,如何高效地联合优化质量和多样性指标,成为一个关键问题。现有方法计算成本高昂,难以应用于万亿级token数据集。
核心思路:DATAMASK的核心思路是将数据选择问题转化为一个掩码学习问题。通过学习一个数据掩码,来决定哪些数据被选择用于预训练。这种方法允许同时考虑多个指标(如质量和多样性),并通过策略梯度优化来找到最优的掩码策略。这种设计旨在克服传统方法的局限性,实现更高效和更有效的预训练数据选择。
技术框架:DATAMASK的整体框架包括以下几个主要阶段:1) 初始化数据掩码采样logits;2) 迭代采样数据掩码;3) 基于采样的掩码和预定义的目标函数计算策略梯度;4) 使用策略梯度更新掩码采样logits。通过迭代执行这些步骤,DATAMASK能够逐步优化数据选择策略,从而选择出更适合预训练的数据子集。
关键创新:DATAMASK最重要的技术创新点在于其将数据选择问题建模为掩码学习问题,并利用策略梯度进行优化。与传统的基于规则或启发式方法的数据选择相比,DATAMASK能够通过学习自动地找到最优的数据选择策略,从而更好地平衡质量和多样性。此外,DATAMASK还采用了各种加速增强技术,以降低计算成本,使其能够应用于大规模数据集。
关键设计:DATAMASK的关键设计包括:1) 使用策略梯度算法(如REINFORCE)来优化掩码采样logits;2) 定义合适的目标函数,以同时考虑质量和多样性指标;3) 采用高效的采样和梯度计算方法,以降低计算成本;4) 设计合适的网络结构来学习数据掩码。
🖼️ 关键图片
📊 实验亮点
DATAMASK在15万亿token的FineWeb数据集上进行了实验,结果表明,与贪婪算法相比,DATAMASK将数据选择时间减少了98.9%。使用DATAMASK选择的10%数据子集(FineWeb-Mask)训练的1.5B密集模型在12个不同任务上取得了3.2%的性能提升,7B MoE模型取得了1.9%的性能提升。
🎯 应用场景
DATAMASK可应用于大规模语言模型的预训练数据选择,提高预训练效率和模型性能。该方法能够有效应用于各种需要大规模数据预训练的场景,例如自然语言处理、计算机视觉等领域。通过优化数据选择,可以降低训练成本,提升模型泛化能力,加速AI技术的落地应用。
📄 摘要(原文)
A fine-grained data recipe is crucial for pre-training large language models, as it can significantly enhance training efficiency and model performance. One important ingredient in the recipe is to select samples based on scores produced by defined rules, LLM judgment, or statistical information in embeddings, which can be roughly categorized into quality and diversity metrics. Due to the high computational cost when applied to trillion-scale token pre-training datasets such as FineWeb and DCLM, these two or more types of metrics are rarely considered jointly in a single selection process. However, in our empirical study, selecting samples based on quality metrics exhibit severe diminishing returns during long-term pre-training, while selecting on diversity metrics removes too many valuable high-quality samples, both of which limit pre-trained LLMs' capabilities. Therefore, we introduce DATAMASK, a novel and efficient joint learning framework designed for large-scale pre-training data selection that can simultaneously optimize multiple types of metrics in a unified process, with this study focusing specifically on quality and diversity metrics. DATAMASK approaches the selection process as a mask learning problem, involving iterative sampling of data masks, computation of policy gradients based on predefined objectives with sampled masks, and updating of mask sampling logits. Through policy gradient-based optimization and various acceleration enhancements, it significantly reduces selection time by 98.9% compared to greedy algorithm, enabling our study to explore joint learning within trillion-scale tokens. With DATAMASK, we select a subset of about 10% from the 15 trillion-token FineWeb dataset, termed FineWeb-Mask. Evaluated across 12 diverse tasks, we achieves significant improvements of 3.2% on a 1.5B dense model and 1.9% on a 7B MoE model.