Compiler-First State Space Duality and Portable $O(1)$ Autoregressive Caching for Inference

📄 arXiv: 2603.09555v1 📥 PDF

作者: Cosmo Santoni

分类: cs.LG, cs.AI, cs.DC, cs.PF

发布日期: 2026-03-10

备注: 18 pages, 6 figures. Code available at: https://github.com/CosmoNaught/mamba2-jax

🔗 代码/项目: GITHUB


💡 一句话要点

利用XLA编译器优化,实现Mamba-2在多平台上的高效可移植推理。

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 状态空间模型 Mamba-2 XLA编译器 JAX框架 硬件加速 模型推理 可移植性

📋 核心要点

  1. 现有状态空间模型推理依赖特定硬件加速库,可移植性差,限制了其应用范围。
  2. 利用XLA编译器优化,将Mamba-2的状态空间对偶算法映射到XLA优化过程,无需定制内核。
  3. 在CPU、GPU和TPU上实现高效推理,单流预填充达到140 TFLOPS,解码带宽利用率高达64%。

📝 摘要(中文)

当前状态空间模型通常依赖于融合的CUDA和Triton内核,从而对NVIDIA硬件产生硬依赖。本文证明了Mamba-2的状态空间对偶算法——对角状态结构、可分块递归以及具有静态控制流的einsum主导计算——能够很好地映射到XLA的融合和分块优化过程,使得定制内核成为可选而非必需。我们使用XLA下的标准原语实现了完整的推理路径(预填充、缓存的自回归解码),无需手写内核,并将该架构的理论O(1)状态管理实现为编译后的片上缓存,在生成过程中无需主机同步。该实现可以在CPU、NVIDIA GPU和Google Cloud TPU上从单个JAX源无修改地运行。在TPU v6e上,跨越五个模型规模(130M-2.7B参数),XLA生成的代码在单流预填充上达到约140 TFLOPS(15% MFU),在解码上达到高达64%的带宽利用率。贪婪解码在64步内与PyTorch/CUDA参考实现逐token匹配,隐藏状态的一致性在float32舍入容差范围内。该模式可以转移到满足相同结构条件的任何SSM递归,以及任何具有成熟XLA后端的平台。该实现已公开发布在https://github.com/CosmoNaught/mamba2-jax,并已合并到Bonsai JAX模型库中。

🔬 方法详解

问题定义:现有状态空间模型(SSM)的推理通常依赖于特定硬件(如NVIDIA GPU)的CUDA和Triton内核,这导致了模型的可移植性差,难以在不同硬件平台上部署。此外,手动编写和优化这些内核需要大量的专业知识和时间投入。

核心思路:本文的核心思路是利用XLA(Accelerated Linear Algebra)编译器,将Mamba-2的状态空间对偶算法映射到XLA的优化过程。Mamba-2的特殊结构(对角状态结构、可分块递归和einsum主导计算)使其能够很好地适应XLA的融合和分块优化,从而避免了手动编写定制内核的需求。

技术框架:该方法使用JAX框架,将Mamba-2的推理过程(包括预填充和缓存的自回归解码)表示为XLA的标准原语。通过XLA编译器,这些原语被自动优化并编译成可在不同硬件平台上运行的代码。关键在于利用XLA的融合和分块pass,将计算密集型的操作进行优化,从而提高推理效率。此外,该方法还实现了片上缓存,用于存储状态变量,避免了主机同步,进一步提高了推理速度。

关键创新:最重要的技术创新在于发现并利用了Mamba-2的结构特性与XLA编译器优化之间的天然契合。通过将Mamba-2的计算过程表示为XLA的标准原语,可以充分利用XLA的自动优化能力,从而在不同硬件平台上实现高效推理,而无需手动编写定制内核。这种方法极大地提高了模型的可移植性和易用性。

关键设计:该实现的关键设计包括:1) 使用JAX框架,将Mamba-2的计算表示为XLA标准原语;2) 利用XLA的融合和分块pass进行自动优化;3) 实现片上缓存,用于存储状态变量,避免主机同步;4) 确保生成的代码在不同硬件平台上的正确性,通过与PyTorch/CUDA参考实现进行token-for-token的比较来验证。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

该方法在TPU v6e上,跨越五个模型规模(130M-2.7B参数),实现了单流预填充约140 TFLOPS(15% MFU)的性能,解码带宽利用率高达64%。与PyTorch/CUDA参考实现相比,贪婪解码在64步内实现了token-for-token的匹配,隐藏状态的一致性在float32舍入容差范围内,验证了该方法在不同硬件平台上的正确性和高效性。

🎯 应用场景

该研究成果可广泛应用于自然语言处理领域,尤其是在需要高性能和可移植性的场景下,如移动设备上的实时翻译、边缘计算设备上的对话系统等。它降低了状态空间模型部署的门槛,使得更多开发者能够在不同硬件平台上高效地运行这些模型,加速相关技术的落地。

📄 摘要(原文)

State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm -- diagonal state structure, chunkable recurrence, and einsum-dominated compute with static control flow -- maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill, cached autoregressive decoding) as shaped standard primitives under XLA, without hand-written kernels, and realise the architecture's theoretical $O(1)$ state management as a compiled on-device cache requiring no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M--2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill ($15%$ MFU) and up to $64%$ bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.