KV-Runahead: Scalable Causal LLM Inference by Parallel Key-Value Cache Generation

📄 arXiv: 2405.05329v2 📥 PDF

作者: Minsik Cho, Mohammad Rastegari, Devang Naik

分类: cs.DC, cs.AI, cs.CL

发布日期: 2024-05-08 (更新: 2024-05-13)

备注: preprint for ICML 2024


💡 一句话要点

提出KV-Runahead并行化方案,加速因果LLM推理的首token生成,提升效率。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大语言模型 LLM推理 并行计算 KV-cache 首token生成时间 因果注意力 负载均衡

📋 核心要点

  1. 现有LLM推理prompt阶段速度慢,成为瓶颈,传统并行化方案通信开销大。
  2. KV-Runahead并行生成KV-cache,利用因果注意力减少计算,易于实现。
  3. 实验表明,KV-Runahead在Llama 7B和Falcon 7B上分别实现了1.4x和1.6x的加速。

📝 摘要(中文)

本文提出了一种名为KV-Runahead的高效并行化方案,旨在加速大语言模型(LLM)推理的prompt阶段。核心思想是利用extension阶段生成token速度快于prompt阶段的特性,通过并行化多个进程来填充KV-cache,从而最小化首个token生成时间(TTFT)。KV-cache方案具有双重优势:一是利用因果注意力机制,自动减少计算量;二是由于extension阶段已存在KV-cache,易于实现。此外,本文还提出了上下文级别的负载均衡,以处理KV-cache生成的不均匀性,进一步优化TTFT。实验结果表明,与现有的张量并行或序列并行方案相比,KV-Runahead在Llama 7B和Falcon 7B上分别实现了超过1.4倍和1.6倍的加速。

🔬 方法详解

问题定义:论文旨在解决大语言模型(LLM)推理过程中,prompt阶段(即生成第一个token之前的阶段)速度较慢的问题。现有的张量并行或序列并行方案需要进行大量的全局通信(all-gather collectives),导致效率较低,成为LLM推理的瓶颈。

核心思路:论文的核心思路是利用LLM推理过程中extension阶段(生成后续token的阶段)速度快于prompt阶段的特点,通过并行化生成Key-Value Cache (KV-cache) 来加速prompt阶段。由于extension阶段已经存在KV-cache,因此可以复用该机制,降低了实现的复杂度。

技术框架:KV-Runahead方案通过多个进程并行生成KV-cache,每个进程负责一部分上下文的KV-cache生成。整体流程如下:1) 将输入prompt分配给多个进程;2) 每个进程并行计算各自负责部分的KV-cache;3) 将生成的KV-cache合并,用于后续的extension阶段的token生成。为了处理因果注意力机制导致的KV-cache生成不均匀问题,论文还提出了上下文级别的负载均衡策略。

关键创新:KV-Runahead的关键创新在于利用了KV-cache的特性,将prompt阶段的计算转化为并行生成KV-cache的过程。与传统的并行化方案相比,KV-Runahead避免了大量的全局通信,从而提高了效率。此外,KV-Runahead还提出了上下文级别的负载均衡策略,进一步优化了性能。

关键设计:论文的关键设计包括:1) 如何将输入prompt分配给多个进程,以实现负载均衡;2) 如何高效地合并多个进程生成的KV-cache;3) 上下文级别负载均衡的具体实现方式,例如,动态调整每个进程负责的上下文长度,以平衡各个进程的计算负载。具体的参数设置和损失函数等细节在论文中可能未详细描述,属于实现层面的优化。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,KV-Runahead在Llama 7B和Falcon 7B模型上分别实现了超过1.4倍和1.6倍的加速,显著优于现有的张量并行和序列并行方案。这表明KV-Runahead在加速LLM推理方面具有显著优势,能够有效降低首个token生成时间。

🎯 应用场景

KV-Runahead可应用于各种需要快速LLM推理的场景,例如实时对话系统、快速文本生成、搜索引擎等。通过降低首个token生成时间,可以显著提升用户体验,并降低部署成本。该研究对于推动LLM在实际应用中的普及具有重要意义。

📄 摘要(原文)

Large Language Model or LLM inference has two phases, the prompt (or prefill) phase to output the first token and the extension (or decoding) phase to the generate subsequent tokens. In this work, we propose an efficient parallelization scheme, KV-Runahead to accelerate the prompt phase. The key observation is that the extension phase generates tokens faster than the prompt phase because of key-value cache (KV-cache). Hence, KV-Runahead parallelizes the prompt phase by orchestrating multiple processes to populate the KV-cache and minimizes the time-to-first-token (TTFT). Dual-purposing the KV-cache scheme has two main benefits. First, since KV-cache is designed to leverage the causal attention map, we minimize computation and computation automatically. Second, since it already exists for the extension phase, KV-Runahead is easy to implement. We further propose context-level load-balancing to handle uneven KV-cache generation (due to the causal attention) and to optimize TTFT. Compared with an existing parallelization scheme such as tensor or sequential parallelization where keys and values are locally generated and exchanged via all-gather collectives, our experimental results demonstrate that KV-Runahead can offer over 1.4x and 1.6x speedups for Llama 7B and Falcon 7B respectively.