SPT: Fine-Tuning Transformer-based Language Models Efficiently with Sparsification

📄 arXiv: 2312.10365v1 📥 PDF

作者: Yuntao Gui, Xiao Yan, Peiqi Yin, Han Yang, James Cheng

分类: cs.DC, cs.AI

发布日期: 2023-12-16

备注: Firstly submitted to VLDB November 1, 2023, rejection received on December 15, 2023


💡 一句话要点

SPT:通过稀疏化高效微调Transformer语言模型

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: Transformer 微调 稀疏化 多头注意力 前馈神经网络 模型压缩 高效计算

📋 核心要点

  1. Transformer模型微调面临内存消耗大和运行时间长的问题,限制了其在资源受限场景的应用。
  2. SPT通过引入稀疏性,设计了稀疏MHA和路由FFN模块,减少内存占用和计算量,提升微调效率。
  3. 实验结果表明,SPT在多种模型配置下均优于现有基线,显著降低内存消耗并加速微调过程。

📝 摘要(中文)

基于Transformer的大型语言模型(如BERT和GPT)取得了巨大成功,而微调是在特定任务数据集上调整预训练模型,是利用这些模型进行下游任务的标准做法。然而,由于模型规模庞大,Transformer微调的运行时间和内存消耗都很高。我们提出了SPT系统,通过引入稀疏性来高效地微调基于Transformer的模型。我们观察到,Transformer的内存消耗主要来自存储多头注意力(MHA)的注意力权重,而大部分运行时间花费在feed-forward network(FFN)上。因此,我们设计了稀疏MHA模块,它只计算和存储大的注意力权重以减少内存消耗,以及路由FFN模块,它为每个token动态激活模型参数的子集以减少计算成本。我们在PyTorch上实现了SPT,并定制了CUDA内核以高效运行稀疏MHA和路由FFN。具体来说,我们使用乘积量化来识别大的注意力权重,并通过稀疏矩阵乘法计算稀疏MHA的注意力。对于路由FFN,我们根据token激活的模型参数对token进行批处理,以实现高效计算。我们进行了广泛的实验,以评估SPT在各种模型配置上的性能。结果表明,SPT始终优于经过良好优化的基线,最多可减少50%的峰值内存消耗,并将微调速度提高高达2.2倍。

🔬 方法详解

问题定义:论文旨在解决Transformer模型微调过程中内存消耗过大和运行时间过长的问题。现有方法在微调大型Transformer模型时,需要消耗大量的计算资源和时间,限制了其在实际应用中的部署。主要瓶颈在于多头注意力机制(MHA)需要存储大量的注意力权重,以及前馈神经网络(FFN)需要进行大量的计算。

核心思路:论文的核心思路是通过引入稀疏性来减少内存消耗和计算量。具体来说,只保留重要的注意力权重,并动态激活部分模型参数,从而降低计算复杂度。这种方法能够在保证模型性能的同时,显著提高微调效率。

技术框架:SPT系统主要包含两个核心模块:稀疏MHA(Sparse MHA)和路由FFN(Routed FFN)。稀疏MHA通过只计算和存储重要的注意力权重来减少内存消耗。路由FFN则根据输入token动态激活一部分模型参数,从而减少计算量。整个流程包括:输入token -> 稀疏MHA -> 路由FFN -> 输出。

关键创新:论文的关键创新在于提出了稀疏MHA和路由FFN两种稀疏化方法。稀疏MHA通过乘积量化来识别重要的注意力权重,并使用稀疏矩阵乘法进行计算,从而减少内存占用。路由FFN则通过动态激活模型参数的子集,减少了计算量,提高了计算效率。与现有方法相比,SPT能够在保证模型性能的同时,显著降低内存消耗和加速微调过程。

关键设计:在稀疏MHA中,使用乘积量化来选择重要的注意力权重。在路由FFN中,根据token激活的模型参数对token进行批处理,以实现高效计算。具体实现中,论文在PyTorch上实现了SPT,并定制了CUDA内核以高效运行稀疏MHA和路由FFN。

📊 实验亮点

实验结果表明,SPT在各种模型配置下均优于现有基线。SPT最多可减少50%的峰值内存消耗,并将微调速度提高高达2.2倍。这些结果表明,SPT是一种高效的Transformer模型微调方法,具有很强的实用价值。

🎯 应用场景

SPT具有广泛的应用前景,可用于在资源受限的环境中高效微调大型Transformer模型,例如移动设备或边缘计算设备。该技术可以加速自然语言处理任务的开发和部署,并降低计算成本。未来,SPT可以进一步扩展到其他类型的深度学习模型和任务中。

📄 摘要(原文)

Transformer-based large language models (e.g., BERT and GPT) achieve great success, and fine-tuning, which tunes a pre-trained model on a task-specific dataset, is the standard practice to utilize these models for downstream tasks. However, Transformer fine-tuning has long running time and high memory consumption due to the large size of the models. We propose the SPT system to fine-tune Transformer-based models efficiently by introducing sparsity. We observe that the memory consumption of Transformer mainly comes from storing attention weights for multi-head attention (MHA), and the majority of running time is spent on feed-forward network (FFN). Thus, we design the sparse MHA module, which computes and stores only large attention weights to reduce memory consumption, and the routed FFN module, which dynamically activates a subset of model parameters for each token to reduce computation cost. We implement SPT on PyTorch and customize CUDA kernels to run sparse MHA and routed FFN efficiently. Specifically, we use product quantization to identify the large attention weights and compute attention via sparse matrix multiplication for sparse MHA. For routed FFN, we batch the tokens according to their activated model parameters for efficient computation. We conduct extensive experiments to evaluate SPT on various model configurations. The results show that SPT consistently outperforms well-optimized baselines, reducing the peak memory consumption by up to 50% and accelerating fine-tuning by up to 2.2x.