Does your data spark joy? Performance gains from domain upsampling at the end of training

📄 arXiv: 2406.03476v1 📥 PDF

作者: Cody Blakeney, Mansheej Paul, Brett W. Larsen, Sean Owen, Jonathan Frankle

分类: cs.LG, cs.CL

发布日期: 2024-06-05

备注: The first three authors contributed equally


💡 一句话要点

提出领域数据末端上采样方法,提升大语言模型在特定任务上的性能

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 领域自适应 数据上采样 预训练 知识迁移

📋 核心要点

  1. 现有方法难以平衡通用网络数据和领域数据,导致特定任务性能提升受限,且实验成本高昂。
  2. 论文提出在训练末期对领域数据进行上采样,以提高模型在特定任务上的性能,降低实验成本。
  3. 实验表明,该方法在MMLU、GSM8K和HumanEval等基准上显著提升性能,媲美训练时间更长的模型。

📝 摘要(中文)

大型语言模型(LLM)的预训练数据集已经增长到数万亿个token,其中包含大量的CommonCrawl(CC)网络抓取数据以及较小的、特定领域的数据集。理解这些特定领域数据集对模型能力的影响非常昂贵,因为需要在大型FLOP规模上进行训练才能揭示对困难和新兴基准的显著改变。鉴于预训练数据实验成本的不断增加,如何确定通用网络抓取的多元化与特定领域数据的信息密度之间的最佳平衡?在这项工作中,我们展示了如何通过在训练结束时对特定领域的数据集进行相对于CC的上采样,来提高困难基准上的性能。这种简单技术使我们能够在MMLU上提高高达6.90个百分点,在GSM8K上提高8.26个百分点,在HumanEval上提高6.17个百分点,相对于训练了1万亿个token的7B模型的基础数据混合,从而与Llama-2(7B)相媲美——一个训练时间是其两倍的模型。我们实验了从训练的5%到30%的领域上采样持续时间,发现10%到20%对于在通用语言建模能力和目标基准之间进行权衡是最佳的。我们还使用领域上采样来大规模地表征各个数据集对于改进各种基准的效用,方法是在训练的最后阶段删除它们。该工具开启了以比完整预训练运行低一个数量级的成本来实验不同预训练数据集影响的能力。

🔬 方法详解

问题定义:现有的大语言模型预训练通常混合通用网络数据(如CommonCrawl)和少量领域特定数据。由于训练成本高昂,难以有效评估和利用领域特定数据对模型在特定任务上的性能提升潜力。现有方法难以在通用能力和特定任务性能之间取得平衡,且探索不同数据组合的成本过高。

核心思路:论文的核心思路是在预训练的最后阶段,对领域特定数据进行上采样。这种方法的核心在于,在模型已经具备一定的通用语言能力后,通过增加领域特定数据的比例,使模型更加专注于学习这些数据中的知识,从而提升在相关任务上的表现。这种方法比从头开始训练或在整个训练过程中保持固定比例的领域数据更高效。

技术框架:整体框架包括以下阶段:1) 使用混合了通用数据和领域数据的标准预训练流程训练模型至一定程度(例如,1万亿token);2) 在训练的最后阶段,提高领域特定数据在训练数据中的比例(上采样);3) 继续训练模型一段时间(例如,总训练的5%-30%);4) 在目标任务上评估模型性能。关键模块是数据采样策略,即如何确定领域数据的上采样比例和持续时间。

关键创新:该方法最重要的创新点在于,它提供了一种高效的方式来利用领域特定数据,而无需从头开始进行昂贵的预训练。通过在训练末期进行上采样,可以在不显著增加计算成本的情况下,显著提高模型在特定任务上的性能。此外,该方法还提供了一种评估不同数据集效用的方法,即通过在末端上采样阶段移除特定数据集,观察对性能的影响。

关键设计:关键设计包括:1) 确定最佳的领域数据上采样比例。论文实验了不同的上采样比例,发现10%-20%的训练时长用于上采样是比较合适的。2) 选择合适的领域数据集。论文使用领域上采样来评估不同数据集对特定任务的贡献,例如,移除某个数据集观察性能变化。3) 损失函数和网络结构保持不变,主要关注数据配比的影响。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,通过在训练末期对领域数据进行上采样,7B模型在MMLU上提升高达6.90个百分点,在GSM8K上提升8.26个百分点,在HumanEval上提升6.17个百分点,性能媲美训练时间是其两倍的Llama-2 (7B)模型。实验还发现,10%-20%的训练时长用于领域数据上采样是最佳选择。

🎯 应用场景

该研究成果可应用于各种需要特定领域知识的大语言模型训练场景,例如金融、医疗、法律等。通过末端上采样,可以低成本地提升模型在特定领域的专业能力,提高模型在相关任务上的准确性和效率。该方法还有助于优化预训练数据集的构成,降低训练成本,加速模型迭代。

📄 摘要(原文)

Pretraining datasets for large language models (LLMs) have grown to trillions of tokens composed of large amounts of CommonCrawl (CC) web scrape along with smaller, domain-specific datasets. It is expensive to understand the impact of these domain-specific datasets on model capabilities as training at large FLOP scales is required to reveal significant changes to difficult and emergent benchmarks. Given the increasing cost of experimenting with pretraining data, how does one determine the optimal balance between the diversity in general web scrapes and the information density of domain specific data? In this work, we show how to leverage the smaller domain specific datasets by upsampling them relative to CC at the end of training to drive performance improvements on difficult benchmarks. This simple technique allows us to improve up to 6.90 pp on MMLU, 8.26 pp on GSM8K, and 6.17 pp on HumanEval relative to the base data mix for a 7B model trained for 1 trillion (T) tokens, thus rivaling Llama-2 (7B)$\unicode{x2014}$a model trained for twice as long. We experiment with ablating the duration of domain upsampling from 5% to 30% of training and find that 10% to 20% percent is optimal for navigating the tradeoff between general language modeling capabilities and targeted benchmarks. We also use domain upsampling to characterize at scale the utility of individual datasets for improving various benchmarks by removing them during this final phase of training. This tool opens up the ability to experiment with the impact of different pretraining datasets at scale, but at an order of magnitude lower cost compared to full pretraining runs.