Transferable text data distillation by trajectory matching
作者: Rong Yao, Hailin Hu, Yifei Fu, Hanting Chen, Wenyi Fang, Fanyi Du, Kai Han, Yunhe Wang
分类: cs.CL
发布日期: 2025-04-14 (更新: 2025-04-24)
💡 一句话要点
提出基于轨迹匹配的可迁移文本数据蒸馏方法,降低大语言模型训练成本。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 数据蒸馏 文本生成 指令调优 轨迹匹配 大语言模型
📋 核心要点
- 大型语言模型训练成本高昂,数据蒸馏旨在用少量合成数据达到全量数据训练效果,但在NLP领域因文本离散性面临挑战。
- 该方法通过轨迹匹配学习伪提示数据,并寻找最近邻ID实现跨架构迁移,同时引入正则化损失提高鲁棒性。
- 实验表明,该方法在ARC-Easy和MMLU数据集上优于SOTA数据选择方法,并展现出良好的跨模型迁移能力。
📝 摘要(中文)
随着大型语言模型(LLM)规模的增大,训练成本也随之增加。因此,迫切需要最小化LLM训练中的数据规模。与数据选择方法相比,数据蒸馏方法旨在合成少量数据样本,以达到完整数据集的训练效果,并具有更好的灵活性。尽管数据蒸馏在计算机视觉领域取得了成功,但文本数据的离散性阻碍了其在自然语言处理(NLP)领域的探索。本文提出了一种基于轨迹匹配学习伪提示数据,并通过寻找其最近邻ID来实现跨架构迁移的方法。在蒸馏过程中,引入正则化损失以提高蒸馏数据的鲁棒性。据我们所知,这是第一个适用于指令调优等文本生成任务的数据蒸馏工作。在ARC-Easy和MMLU指令调优数据集上的评估表明,我们的蒸馏方法优于SOTA数据选择方法LESS。此外,我们的方法展示了良好的跨LLM结构(即,OPT到Llama)的可迁移性。
🔬 方法详解
问题定义:现有的大型语言模型(LLM)训练需要大量数据,导致训练成本高昂。数据选择方法虽然可以减少数据量,但效果有限。数据蒸馏方法在计算机视觉领域表现出色,但由于文本数据的离散性,直接应用到自然语言处理(NLP)领域存在困难。因此,需要一种适用于文本数据,特别是指令调优任务的数据蒸馏方法,以降低LLM的训练成本。
核心思路:该论文的核心思路是利用轨迹匹配来学习伪提示数据,并将其转化为离散的文本ID。具体来说,通过优化伪提示,使得使用这些伪提示训练的小模型在性能上尽可能接近使用原始数据训练的大模型。然后,通过寻找伪提示的最近邻ID,将其转化为可用于训练的离散文本数据。这种方法能够有效地将原始数据的知识迁移到少量合成数据中。
技术框架:该方法主要包含以下几个阶段:1) 初始化伪提示:随机初始化一组伪提示数据。2) 轨迹匹配:使用伪提示数据训练一个小模型,并计算其训练轨迹与使用原始数据训练的大模型的训练轨迹之间的距离。3) 优化伪提示:通过最小化轨迹距离来优化伪提示数据。4) 最近邻ID查找:将优化后的伪提示数据映射到离散的文本ID空间,找到其最近邻ID。5) 正则化:引入正则化损失,提高蒸馏数据的鲁棒性。
关键创新:该论文的关键创新在于将轨迹匹配的思想引入到文本数据蒸馏中。与以往的数据蒸馏方法不同,该方法不需要直接操作离散的文本数据,而是通过优化连续的伪提示数据来实现知识迁移。此外,该方法还引入了正则化损失,提高了蒸馏数据的鲁棒性,使其能够更好地泛化到不同的模型结构上。
关键设计:在轨迹匹配阶段,使用了余弦相似度来衡量训练轨迹之间的距离。在优化伪提示数据时,使用了Adam优化器。为了提高蒸馏数据的鲁棒性,引入了L2正则化损失。在最近邻ID查找阶段,使用了k-NN算法。具体参数设置(如学习率、正则化系数、k值等)需要根据具体的实验进行调整。
🖼️ 关键图片
📊 实验亮点
该论文在ARC-Easy和MMLU指令调优数据集上进行了实验,结果表明,提出的数据蒸馏方法优于SOTA数据选择方法LESS。此外,该方法还展示了良好的跨LLM结构的可迁移性,例如可以将使用OPT模型蒸馏得到的数据用于训练Llama模型,并且能够取得良好的性能。
🎯 应用场景
该研究成果可应用于各种需要降低大语言模型训练成本的场景,例如在资源受限的环境下训练特定领域的LLM,或者在边缘设备上部署轻量级的LLM。此外,该方法还可以用于数据增强,通过生成高质量的合成数据来提高模型的泛化能力。未来,该方法有望推广到更多NLP任务中,例如机器翻译、文本摘要等。
📄 摘要(原文)
In the realm of large language model (LLM), as the size of large models increases, it also brings higher training costs. There is a urgent need to minimize the data size in LLM training. Compared with data selection method, the data distillation method aims to synthesize a small number of data samples to achieve the training effect of the full data set and has better flexibility. Despite its successes in computer vision, the discreteness of text data has hitherto stymied its exploration in natural language processing (NLP). In this work, we proposed a method that involves learning pseudo prompt data based on trajectory matching and finding its nearest neighbor ID to achieve cross-architecture transfer. During the distillation process, we introduce a regularization loss to improve the robustness of our distilled data. To our best knowledge, this is the first data distillation work suitable for text generation tasks such as instruction tuning. Evaluations on two benchmarks, including ARC-Easy and MMLU instruction tuning datasets, established the superiority of our distillation approach over the SOTA data selection method LESS. Furthermore, our method demonstrates a good transferability over LLM structures (i.e., OPT to Llama).