LoLCATs: On Low-Rank Linearizing of Large Language Models
作者: Michael Zhang, Simran Arora, Rahul Chalamala, Alan Wu, Benjamin Spector, Aaryan Singhal, Krithik Ramesh, Christopher Ré
分类: cs.LG, cs.AI, cs.CL, stat.ML
发布日期: 2024-10-14 (更新: 2025-03-05)
备注: 58 pages, 25 figures, 26 tables, ICLR 2025
💡 一句话要点
LoLCATs:通过低秩线性化方法提升大型语言模型的效率与性能
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 大型语言模型 线性化 注意力机制 低秩自适应 注意力迁移 模型压缩 高效推理
📋 核心要点
- 现有线性化LLM的方法在模型质量上存在显著下降,且训练成本高昂,并局限于较小的模型规模(1.3B-7B)。
- LoLCATs通过注意力迁移和低秩自适应,在保证模型质量的同时,显著降低了线性化LLM的计算和内存需求。
- 实验表明,LoLCATs在多个模型和数据集上取得了显著的性能提升,并成功应用于更大规模的LLM(70B和405B)。
📝 摘要(中文)
本文提出了一种名为LoLCATs(通过注意力迁移的低秩线性转换)的两步方法,旨在提高大型语言模型(LLM)线性化的质量,同时显著降低内存和计算成本。该方法基于两个发现:首先,通过训练线性注意力机制以匹配softmax注意力机制的输出(注意力迁移),可以有效地用线性注意力替代LLM中的softmax注意力。其次,通过低秩自适应(LoRA)可以调整近似误差并恢复LLM的质量。LoLCATs显著提高了线性化的质量、训练效率和可扩展性,在Llama 3 8B和Mistral 7B v0.1上取得了最先进的亚二次复杂度LLM,并在5-shot MMLU上实现了超过20个点的改进。此外,LoLCATs仅使用了先前方法0.2%的模型参数和0.4%的训练token。最后,LoLCATs被应用于创建首个线性化的70B和405B LLM(比先前工作大50倍)。在相同的计算预算下,LoLCATs显著提高了线性化的质量,在5-shot MMLU上将线性化后的Llama 3.1 70B和405B LLM与原始模型之间的差距缩小了77.8%和78.1%。
🔬 方法详解
问题定义:论文旨在解决大型语言模型(LLM)线性化过程中模型质量下降、训练成本高昂以及模型规模受限的问题。现有方法通常需要在数十亿token上进行训练,并且无法有效地扩展到更大规模的LLM。
核心思路:论文的核心思路是通过两步法:首先,使用注意力迁移将softmax注意力替换为近似的线性注意力,从而降低计算复杂度;然后,利用低秩自适应(LoRA)来调整由于线性化带来的近似误差,从而恢复模型性能。这种方法旨在在降低计算成本的同时,保持甚至提升模型质量。
技术框架:LoLCATs方法包含两个主要阶段:1) 注意力迁移:训练线性注意力机制,使其输出尽可能接近原始softmax注意力机制的输出。这通常通过最小化输出的均方误差(MSE)来实现。2) 低秩自适应(LoRA):在经过注意力迁移的线性化模型上应用LoRA,通过引入少量可训练的低秩矩阵来调整模型参数,从而弥补线性化带来的性能损失。
关键创新:该方法最重要的创新点在于将注意力迁移和低秩自适应相结合,从而在降低计算复杂度的同时,有效地恢复了模型性能。与现有方法相比,LoLCATs能够以更少的参数和更少的训练数据,实现更高的线性化质量和更好的可扩展性。
关键设计:在注意力迁移阶段,关键在于选择合适的线性注意力机制和损失函数。论文使用MSE损失来衡量线性注意力与softmax注意力之间的差异。在LoRA阶段,关键在于选择合适的秩(rank)和学习率,以平衡模型性能和训练成本。此外,论文还探索了不同的线性注意力机制和LoRA配置,以优化模型性能。
🖼️ 关键图片
📊 实验亮点
LoLCATs在Llama 3 8B和Mistral 7B v0.1上取得了显著的性能提升,在5-shot MMLU上实现了超过20个点的改进。与现有方法相比,LoLCATs仅使用了0.2%的模型参数和0.4%的训练token。此外,LoLCATs成功应用于创建首个线性化的70B和405B LLM,并在相同的计算预算下,将线性化后的Llama 3.1 70B和405B LLM与原始模型之间的差距缩小了77.8%和78.1%。
🎯 应用场景
LoLCATs方法可广泛应用于需要高效推理和部署的大型语言模型场景,例如移动设备上的自然语言处理、边缘计算环境下的智能助手、以及对计算资源有限制的云服务。该方法降低了LLM的部署成本,使其能够服务于更广泛的用户群体,并促进了LLM在实际应用中的普及。
📄 摘要(原文)
Recent works show we can linearize large language models (LLMs) -- swapping the quadratic attentions of popular Transformer-based LLMs with subquadratic analogs, such as linear attention -- avoiding the expensive pretraining costs. However, linearizing LLMs often significantly degrades model quality, still requires training over billions of tokens, and remains limited to smaller 1.3B to 7B LLMs. We thus propose Low-rank Linear Conversion via Attention Transfer (LoLCATs), a simple two-step method that improves LLM linearizing quality with orders of magnitudes less memory and compute. We base these steps on two findings. First, we can replace an LLM's softmax attentions with closely-approximating linear attentions, simply by training the linear attentions to match their softmax counterparts with an output MSE loss ("attention transfer"). Then, this enables adjusting for approximation errors and recovering LLM quality simply with low-rank adaptation (LoRA). LoLCATs significantly improves linearizing quality, training efficiency, and scalability. We significantly reduce the linearizing quality gap and produce state-of-the-art subquadratic LLMs from Llama 3 8B and Mistral 7B v0.1, leading to 20+ points of improvement on 5-shot MMLU. Furthermore, LoLCATs does so with only 0.2% of past methods' model parameters and 0.4% of their training tokens. Finally, we apply LoLCATs to create the first linearized 70B and 405B LLMs (50x larger than prior work). When compared with prior approaches under the same compute budgets, LoLCATs significantly improves linearizing quality, closing the gap between linearized and original Llama 3.1 70B and 405B LLMs by 77.8% and 78.1% on 5-shot MMLU.