ThinK: Thinner Key Cache by Query-Driven Pruning
作者: Yuhui Xu, Zhanming Jie, Hanze Dong, Lei Wang, Xudong Lu, Aojun Zhou, Amrita Saha, Caiming Xiong, Doyen Sahoo
分类: cs.CL, cs.AI
发布日期: 2024-07-30 (更新: 2025-02-27)
备注: ICLR 2025 (Spotlight)
🔗 代码/项目: GITHUB
💡 一句话要点
ThinK:一种查询驱动的KV缓存剪枝方法,用于减少长序列LLM推理的内存占用。
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: KV缓存剪枝 长序列建模 大型语言模型 内存优化 查询驱动 推理加速
📋 核心要点
- 现有方法主要基于序列长度优化KV缓存,忽略了通道维度上的冗余,导致内存利用率不高。
- ThinK提出一种查询驱动的KV缓存剪枝方法,通过选择性地剪枝不重要的通道来减少内存占用,同时保持模型精度。
- 实验表明,ThinK能显著降低KV缓存内存成本,例如与KIVI集成时,峰值内存使用量减少2.8倍,批量大小增加5倍。
📝 摘要(中文)
大型语言模型(LLM)彻底改变了自然语言处理领域,在各种应用中取得了前所未有的性能。然而,它们不断增长的计算和内存需求带来了重大挑战,尤其是在处理长序列时。本文着重于长上下文场景,旨在解决推理过程中KV缓存内存消耗的低效问题。与现有基于序列长度优化内存的方法不同,我们发现KV缓存在通道维度上存在大量冗余,这体现在注意力权重的不均匀幅度分布和低秩结构中。为此,我们提出ThinK,一种新颖的查询相关的KV缓存剪枝方法,旨在最小化注意力权重损失,同时选择性地剪枝最不重要的通道。我们的方法不仅保持或提高了模型精度,而且与原始KV缓存驱逐和量化方法相比,KV缓存内存成本降低了20%以上。例如,ThinK与KIVI集成可以实现2.8倍的峰值内存使用量减少,同时保持几乎相同的质量,从而在使用单个GPU时可以将批量大小增加高达5倍。在各种长序列数据集上对LLaMA和Mistral模型进行的大量评估验证了ThinK的效率,为高效LLM部署建立了一个新的基线算法,且不影响性能。我们的代码已在https://github.com/SalesforceAIResearch/ThinK上提供。
🔬 方法详解
问题定义:论文旨在解决大型语言模型在长序列推理过程中,KV缓存占用大量内存的问题。现有方法主要关注序列长度上的优化,忽略了KV缓存在通道维度上的冗余,导致内存利用率低下,限制了模型部署和应用。
核心思路:论文的核心思路是观察到KV缓存在通道维度上存在冗余,注意力权重呈现不均匀的幅度分布和低秩结构。因此,可以通过剪枝KV缓存中不重要的通道来减少内存占用,同时尽量保持模型的性能。剪枝过程需要依赖查询向量,以保证剪枝后的KV缓存能够更好地服务于当前的推理任务。
技术框架:ThinK的核心流程是:首先,对于每个查询向量,计算其与KV缓存中各个通道的相关性;然后,根据相关性对通道进行排序,并剪枝掉相关性较低的通道;最后,使用剪枝后的KV缓存进行推理。该框架可以与现有的KV缓存优化方法(如KIVI)相结合,进一步提升性能。
关键创新:ThinK的关键创新在于提出了查询驱动的KV缓存剪枝方法。与传统的静态剪枝方法不同,ThinK根据当前的查询向量动态地选择需要剪枝的通道,从而更好地适应不同的输入序列和推理任务。这种方法能够在减少内存占用的同时,最大限度地保持模型的性能。
关键设计:ThinK的关键设计包括:1) 使用注意力权重来衡量查询向量与KV缓存通道的相关性;2) 设计了一种损失函数,用于指导剪枝过程,以最小化注意力权重损失;3) 采用了一种自适应的剪枝策略,根据不同的查询向量和KV缓存状态,动态地调整剪枝比例。
🖼️ 关键图片
📊 实验亮点
实验结果表明,ThinK在LLaMA和Mistral模型上均取得了显著的性能提升。与原始KV缓存驱逐和量化方法相比,ThinK可以将KV缓存内存成本降低20%以上。与KIVI集成后,ThinK可以实现2.8倍的峰值内存使用量减少,同时保持几乎相同的质量,并可以将单个GPU上的批量大小增加高达5倍。这些结果表明,ThinK是一种高效且实用的LLM部署优化方法。
🎯 应用场景
ThinK可应用于各种需要处理长序列的LLM应用场景,如长文本摘要、机器翻译、对话系统等。通过降低内存占用,ThinK能够使LLM在资源受限的设备上运行,并提高推理效率,从而加速LLM的部署和应用。此外,ThinK还可以与其他KV缓存优化技术相结合,进一步提升性能。
📄 摘要(原文)
Large Language Models (LLMs) have revolutionized the field of natural language processing, achieving unprecedented performance across a variety of applications. However, their increased computational and memory demands present significant challenges, especially when handling long sequences. This paper focuses on the long-context scenario, addressing the inefficiencies in KV cache memory consumption during inference. Unlike existing approaches that optimize the memory based on the sequence length, we identify substantial redundancy in the channel dimension of the KV cache, as indicated by an uneven magnitude distribution and a low-rank structure in the attention weights. In response, we propose ThinK, a novel query-dependent KV cache pruning method designed to minimize attention weight loss while selectively pruning the least significant channels. Our approach not only maintains or enhances model accuracy but also achieves a reduction in KV cache memory costs by over 20% compared with vanilla KV cache eviction and quantization methods. For instance, ThinK integrated with KIVI can achieve a 2.8x reduction in peak memory usage while maintaining nearly the same quality, enabling up to a 5x increase in batch size when using a single GPU. Extensive evaluations on the LLaMA and Mistral models across various long-sequence datasets verified the efficiency of ThinK, establishing a new baseline algorithm for efficient LLM deployment without compromising performance. Our code has been made available at https://github.com/SalesforceAIResearch/ThinK.