Correlation-Aware Select and Merge Attention for Efficient Fine-Tuning and Context Length Extension

📄 arXiv: 2410.04211v1 📥 PDF

作者: Ning Wang, Zekun Li, Tongxin Bai, Guoqi Li

分类: cs.CL, cs.AI

发布日期: 2024-10-05

备注: 11 pages, 2 figures


💡 一句话要点

提出相关性感知选择与合并注意力机制,高效微调并扩展LLM上下文长度。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 长文本建模 上下文长度扩展 稀疏注意力 位置编码 高效微调

📋 核心要点

  1. 现有大型语言模型扩展上下文长度面临技术和资源挑战,需要更高效的方法。
  2. 论文提出相关性感知选择与合并注意力机制,结合位置编码增强泛化能力,实现高效稀疏注意力。
  3. 实验表明,该方法能以更少资源在更长序列上微调模型,并在长上下文推理中保持高性能。

📝 摘要(中文)

本文提出了一种高效且灵活的注意力架构,旨在以更少的计算资源和微调时间扩展大型语言模型的上下文长度。该方法引入了相关性感知的选择和合并机制,以实现高效的稀疏注意力。此外,还提出了一种新颖的数据增强技术,涉及位置编码,以增强模型对未见位置的泛化能力。实验结果表明,使用单个A100 GPU,可以在32K序列长度上对Llama2-7B进行微调,效率优于其他依赖子集回归的方法。该方法还提供了一种全面的上下文长度扩展方案,涵盖预训练、微调和推理阶段。在预训练期间,注意力机制在token选择时部分打破了平移不变性,因此仅对选定的token应用位置编码。在微调阶段,引入了循环、随机截断和动态增长的NTK位置嵌入(CRD NTK)。该设计允许仅使用16K的序列长度进行微调,使Llama2-7B和Mistral-7B等模型能够执行高达1M甚至任意长度的上下文推理。该方法在4M上下文长度的passkey任务上实现了100%的准确率,并在1M上下文长度下保持了稳定的困惑度。与传统的全注意力机制相比,资源需求至少降低了64倍,同时仍能实现具有竞争力的性能。

🔬 方法详解

问题定义:现有大型语言模型在处理长序列时面临计算资源和技术上的挑战。传统的全注意力机制计算复杂度高,难以扩展到更长的上下文长度。现有的稀疏注意力方法虽然降低了计算量,但在性能上有所损失,或者需要复杂的训练策略。因此,如何高效地扩展大型语言模型的上下文长度,同时保持或提升性能,是一个亟待解决的问题。

核心思路:论文的核心思路是通过相关性感知的选择和合并机制,实现高效的稀疏注意力。具体来说,首先根据token之间的相关性选择一部分重要的token,然后对这些token进行注意力计算。这样可以减少计算量,同时保留重要的信息。此外,论文还提出了一种新的位置编码方法,以增强模型对未见位置的泛化能力。

技术框架:该方法主要包含三个阶段:预训练、微调和推理。在预训练阶段,使用相关性感知的选择和合并注意力机制训练模型。在微调阶段,使用循环、随机截断和动态增长的NTK位置嵌入(CRD NTK)对模型进行微调,以扩展上下文长度。在推理阶段,可以使用扩展后的上下文长度进行推理。整体架构是在Transformer的基础上,用提出的注意力机制替换了原有的全注意力机制。

关键创新:该方法最重要的技术创新点在于相关性感知的选择和合并注意力机制。与传统的稀疏注意力方法不同,该方法不是随机地选择token,而是根据token之间的相关性进行选择。这样可以保留更重要的信息,从而提高性能。此外,CRD NTK位置编码也是一个创新点,它允许模型在微调阶段学习到更长的上下文长度。

关键设计:在相关性感知的选择机制中,使用一个可学习的矩阵来计算token之间的相关性。选择top-k个相关性最高的token进行注意力计算。在合并机制中,将选择的token合并成一个更小的集合,以进一步减少计算量。CRD NTK位置编码通过循环、随机截断和动态增长的方式,使得模型可以学习到更长的上下文长度,同时避免了位置编码的冲突。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,该方法在Llama2-7B上实现了32K序列长度的微调,效率优于其他方法。在4M上下文长度的passkey任务上实现了100%的准确率,并在1M上下文长度下保持了稳定的困惑度。与传统的全注意力机制相比,资源需求至少降低了64倍,同时仍能实现具有竞争力的性能。这些结果表明,该方法在扩展上下文长度和降低计算资源需求方面具有显著优势。

🎯 应用场景

该研究成果可广泛应用于需要处理长序列的自然语言处理任务中,例如长文本摘要、机器翻译、对话系统、代码生成等。通过扩展模型的上下文长度,可以提高模型对长文本的理解能力,从而提升任务的性能。此外,该方法还可以降低计算资源的需求,使得在资源有限的条件下也能训练和部署大型语言模型。未来,该方法有望推动大型语言模型在更多领域的应用。

📄 摘要(原文)

Modeling long sequences is crucial for various large-scale models; however, extending existing architectures to handle longer sequences presents significant technical and resource challenges. In this paper, we propose an efficient and flexible attention architecture that enables the extension of context lengths in large language models with reduced computational resources and fine-tuning time compared to other excellent methods. Specifically, we introduce correlation-aware selection and merging mechanisms to facilitate efficient sparse attention. In addition, we also propose a novel data augmentation technique involving positional encodings to enhance generalization to unseen positions. The results are as follows: First, using a single A100, we achieve fine-tuning on Llama2-7B with a sequence length of 32K, which is more efficient than other methods that rely on subsets for regression. Second, we present a comprehensive method for extending context lengths across the pre-training, fine-tuning, and inference phases. During pre-training, our attention mechanism partially breaks translation invariance during token selection, so we apply positional encodings only to the selected tokens. This approach achieves relatively high performance and significant extrapolation capabilities. For fine-tuning, we introduce Cyclic, Randomly Truncated, and Dynamically Growing NTK Positional Embedding (CRD NTK). This design allows fine-tuning with a sequence length of only 16K, enabling models such as Llama2-7B and Mistral-7B to perform inference with context lengths of up to 1M or even arbitrary lengths. Our method achieves 100\% accuracy on the passkey task with a context length of 4M and maintains stable perplexity at a 1M context length. This represents at least a 64-fold reduction in resource requirements compared to traditional full-attention mechanisms, while still achieving competitive performance.