Ragged Paged Attention: A High-Performance and Flexible LLM Inference Kernel for TPU

📄 arXiv: 2604.15464v1 📥 PDF

作者: Jevin Jiang, Ying Chen, Blake A. Hechtman, Fenghui Zhang, Yarong Mu

分类: cs.PF, cs.AI, cs.LG

发布日期: 2026-04-16

备注: 23 pages, 19 figures, 12 tables


💡 一句话要点

提出Ragged Paged Attention,为TPU上的LLM推理提供高性能和灵活的内核。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大型语言模型 TPU 推理优化 Paged Attention Ragged Batch

📋 核心要点

  1. 现有LLM推理内核主要面向GPU,缺乏针对TPU架构的优化,无法有效处理动态和不规则的执行模式。
  2. Ragged Paged Attention (RPA) 通过细粒度分块、定制软件流水线和分布感知编译策略来优化TPU上的LLM推理。
  3. 在Llama 3 8B模型上,RPA在TPU7x上实现了高达86%的内存带宽利用率和73%的模型FLOPs利用率。

📝 摘要(中文)

大型语言模型(LLM)的部署越来越多地转向像谷歌的张量处理单元(TPU)这样具有成本效益的加速器,这既要考虑性能,也要考虑总体拥有成本(TCO)。然而,现有的LLM推理内核和服务系统主要还是以GPU为中心,并且没有成熟的方法能够有效地将LLM工作负载映射到TPU架构上——尤其是在现代服务中常见的动态和不规则的执行模式下。在本文中,我们提出了Ragged Paged Attention(RPA),这是一个用于TPU的高性能和灵活的attention内核,使用Pallas和Mosaic实现。RPA通过三个关键技术来解决这些挑战:(1)细粒度的分块,以实现对不规则内存的有效动态切片;(2)一个定制的软件流水线,将KV缓存更新与attention计算融合;(3)一种分布感知的编译策略,为decode、prefill和混合工作负载生成专门的内核。在TPU7x上对Llama 3 8B进行评估,RPA在decode中实现了高达86%的内存带宽利用率(MBU),在prefill中实现了73%的模型FLOPs利用率(MFU)。RPA作为vLLM和SGLang中的主要TPU后端集成,为高效的TPU推理提供了生产级的基石,并为内核设计提供了实践性的见解。

🔬 方法详解

问题定义:现有LLM推理框架在TPU上的效率不高,尤其是在处理动态和不规则的ragged batch时。传统的GPU优化方法无法直接应用于TPU架构,导致内存带宽利用率低,计算效率低下。此外,KV缓存的更新和attention计算是两个独立的过程,增加了延迟和资源消耗。

核心思路:RPA的核心思路是通过细粒度的分块和定制的软件流水线来优化TPU上的内存访问和计算。通过将KV缓存更新与attention计算融合,减少了数据传输和同步的开销。分布感知的编译策略则根据不同的工作负载(decode, prefill, mixed)生成专门的内核,进一步提高了效率。

技术框架:RPA的整体框架包括三个主要部分:细粒度分块、定制软件流水线和分布感知编译。细粒度分块将不规则的内存区域划分为更小的块,以便更有效地进行动态切片和内存访问。定制软件流水线将KV缓存更新与attention计算融合,减少了数据传输和同步的开销。分布感知编译根据不同的工作负载生成专门的内核,以优化性能。

关键创新:RPA的关键创新在于其针对TPU架构的优化策略,包括细粒度分块、定制软件流水线和分布感知编译。与传统的GPU优化方法不同,RPA充分利用了TPU的特性,实现了更高的内存带宽利用率和计算效率。将KV缓存更新与attention计算融合也是一个重要的创新,减少了数据传输和同步的开销。

关键设计:RPA的关键设计包括:(1) 使用Pallas和Mosaic进行内核实现,以实现高性能和灵活性;(2) 设计了细粒度的分块策略,以适应不规则的内存布局;(3) 开发了定制的软件流水线,将KV缓存更新与attention计算融合;(4) 实现了分布感知的编译策略,根据不同的工作负载生成专门的内核。具体的参数设置和网络结构细节未在摘要中详细说明,属于未知信息。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

RPA在TPU7x上对Llama 3 8B模型进行了评估,在decode阶段实现了高达86%的内存带宽利用率(MBU),在prefill阶段实现了73%的模型FLOPs利用率(MFU)。这些结果表明,RPA能够有效地利用TPU的硬件资源,实现高性能的LLM推理。RPA作为vLLM和SGLang中的主要TPU后端集成,也证明了其在实际应用中的价值。

🎯 应用场景

RPA可应用于各种需要高性能LLM推理的场景,例如对话机器人、文本生成、代码生成等。通过优化TPU上的推理效率,RPA可以降低部署成本,提高服务质量,并推动LLM在更多领域的应用。RPA作为vLLM和SGLang的TPU后端,也为其他研究者和开发者提供了高效的TPU推理基础。

📄 摘要(原文)

Large Language Model (LLM) deployment is increasingly shifting to cost-efficient accelerators like Google's Tensor Processing Units (TPUs), prioritizing both performance and total cost of ownership (TCO). However, existing LLM inference kernels and serving systems remain largely GPU-centric, and there is no well-established approach for efficiently mapping LLM workloads onto TPU architectures--particularly under the dynamic and ragged execution patterns common in modern serving. In this paper, we present Ragged Paged Attention (RPA), a high-performance and flexible attention kernel for TPUs, implemented using Pallas and Mosaic. RPA addresses these challenges through three key techniques: (1) fine-grained tiling to enable efficient dynamic slicing over ragged memory, (2) a custom software pipeline that fuses KV cache updates with attention computation, and (3) a distribution-aware compilation strategy that generates specialized kernels for decode, prefill, and mixed workloads. Evaluated on Llama 3 8B on TPU7x, RPA achieves up to 86% memory bandwidth utilization (MBU) in decode and 73% model FLOPs utilization (MFU) in prefill. Integrated as the primary TPU backend in vLLM and SGLang, RPA provides a production-grade foundation for efficient TPU inference and offers practical insights into kernel design.