TokenButler: Token Importance is Predictable

📄 arXiv: 2503.07518v1 📥 PDF

作者: Yash Akhauri, Ahmed F AbouElhamayed, Yifei Gao, Chi-Chih Chang, Nilesh Jain, Mohamed S. Abdelfattah

分类: cs.CL, cs.AI, cs.LG

发布日期: 2025-03-10

🔗 代码/项目: GITHUB


💡 一句话要点

TokenButler:提出一种可预测Token重要性的方法,缓解LLM KV-Cache瓶颈。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 KV-Cache Token重要性 上下文管理 模型优化

📋 核心要点

  1. 现有LLM的KV-Cache机制存在内存和计算瓶颈,关键在于有效识别并保留重要token。
  2. TokenButler通过训练轻量级预测器,根据上下文和预测重要性对token进行优先级排序。
  3. 实验表明,TokenButler在困惑度和下游任务准确率上优于现有方法,并在共指检索任务中表现出色。

📝 摘要(中文)

大型语言模型(LLM)依赖于Key-Value(KV)缓存来存储token历史,从而实现高效的token解码。随着KV-Cache的增长,它成为了主要的内存和计算瓶颈。然而,由于先前的研究表明只有一小部分token对每个解码步骤有意义的贡献,因此存在缓解此瓶颈的机会。找到这些关键token的一个主要挑战是它们是动态的,并且严重依赖于输入查询。现有方法要么通过永久删除token来冒险降低质量,要么保留完整的KV-Cache,但在生成时依赖于检索token的块(页面),这在密集、上下文丰富的任务中会失败。此外,许多现有的KV-Cache稀疏方法依赖于不准确的token重要性代理。为了解决这些限制,我们引入了TokenButler,这是一种高粒度、查询感知的预测器,它学习识别这些关键token。通过训练一个参数开销小于1.2%的轻量级预测器,TokenButler根据其上下文的、预测的重要性来优先排序token。相对于用于估计token重要性的SoTA方法,这提高了困惑度和下游准确率超过8%。我们在一个新的合成小上下文共指检索任务上评估TokenButler,证明了接近oracle的准确性。

🔬 方法详解

问题定义:大型语言模型依赖KV-Cache存储token历史,但随着上下文长度增加,KV-Cache成为内存和计算瓶颈。现有方法要么永久删除token导致质量下降,要么检索token块效率低,或者依赖不准确的token重要性代理,无法有效识别和保留关键token。

核心思路:TokenButler的核心思路是训练一个轻量级的、查询感知的预测器,该预测器能够根据token的上下文信息预测其重要性。通过预测token的重要性,TokenButler可以优先保留重要的token,从而在减少KV-Cache大小的同时,尽可能地保持模型的性能。

技术框架:TokenButler包含一个轻量级的预测器,该预测器与LLM并行工作。在每个解码步骤中,预测器接收LLM的输入查询和KV-Cache中的token信息,然后预测每个token的重要性得分。根据这些得分,TokenButler决定哪些token应该保留在KV-Cache中,哪些应该被丢弃。整体流程包括:输入查询 -> 特征提取 -> 重要性预测 -> KV-Cache更新 -> token解码。

关键创新:TokenButler的关键创新在于其高粒度和查询感知的token重要性预测。与现有方法相比,TokenButler能够更准确地识别关键token,并且能够根据不同的输入查询动态地调整token的优先级。此外,TokenButler的轻量级设计使其易于集成到现有的LLM中,而不会引入过多的计算开销。

关键设计:TokenButler的预测器是一个小型神经网络,其输入包括token的上下文信息(例如,token的嵌入向量、token的位置信息)和查询信息(例如,当前解码步骤的输入)。预测器的输出是一个标量值,表示token的重要性得分。损失函数采用交叉熵损失,目标是最小化预测重要性与真实重要性之间的差异。网络结构采用多层感知机,参数量控制在LLM参数量的1.2%以内。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

TokenButler在token重要性估计方面,相对于SoTA方法,困惑度和下游准确率提高了超过8%。在合成小上下文共指检索任务中,TokenButler实现了接近oracle的准确性,证明了其有效性。该方法仅引入小于1.2%的参数开销,具有很高的实用价值。

🎯 应用场景

TokenButler可应用于各种需要处理长上下文的LLM应用场景,例如机器翻译、文本摘要、对话系统和代码生成。通过减少KV-Cache的大小,TokenButler可以降低LLM的内存需求和计算成本,使其更容易部署在资源受限的设备上,并提高LLM的推理速度。此外,TokenButler还可以用于提高LLM在密集、上下文丰富的任务中的性能。

📄 摘要(原文)

Large Language Models (LLMs) rely on the Key-Value (KV) Cache to store token history, enabling efficient decoding of tokens. As the KV-Cache grows, it becomes a major memory and computation bottleneck, however, there is an opportunity to alleviate this bottleneck, especially because prior research has shown that only a small subset of tokens contribute meaningfully to each decoding step. A key challenge in finding these critical tokens is that they are dynamic, and heavily input query-dependent. Existing methods either risk quality by evicting tokens permanently, or retain the full KV-Cache but rely on retrieving chunks (pages) of tokens at generation, failing at dense, context-rich tasks. Additionally, many existing KV-Cache sparsity methods rely on inaccurate proxies for token importance. To address these limitations, we introduce TokenButler, a high-granularity, query-aware predictor that learns to identify these critical tokens. By training a light-weight predictor with less than 1.2% parameter overhead, TokenButler prioritizes tokens based on their contextual, predicted importance. This improves perplexity & downstream accuracy by over 8% relative to SoTA methods for estimating token importance. We evaluate TokenButler on a novel synthetic small-context co-referential retrieval task, demonstrating near-oracle accuracy. Code, models and benchmarks: https://github.com/abdelfattah-lab/TokenButler