Mahjax: A GPU-Accelerated Mahjong Simulator for Reinforcement Learning in JAX

📄 arXiv: 2605.20577v1 📥 PDF

作者: Soichiro Nishimori, Shinri Okano, Keigo Habara, Sotetsu Koyamada, Eason Yu, Masashi Sugiyama

分类: cs.AI, cs.LG

发布日期: 2026-05-20


💡 一句话要点

Mahjax:一款用于JAX强化学习的GPU加速麻将模拟器

🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)

关键词: 麻将 强化学习 JAX GPU加速 并行计算

📋 核心要点

  1. 现有麻将强化学习研究依赖人类棋谱的监督学习,泛化性受限,从零学习算法更具潜力。
  2. Mahjax利用JAX实现完全向量化的麻将环境,支持GPU大规模并行rollout,加速训练过程。
  3. 实验表明,Mahjax在GPU上实现了极高的吞吐量,并验证了其在强化学习中训练agent的有效性。

📝 摘要(中文)

立直麻将是一个多人、不完全信息博弈,其特点是随机性和高维状态空间。这些属性构成了一种独特的挑战组合,反映了强化学习中复杂的现实世界决策问题。虽然之前的研究主要依赖于从人类游戏记录中进行监督学习来预训练策略,但能够从零开始学习的算法(如AlphaZero系列)具有更大的通用性潜力。为了促进此类研究,我们引入了 extbf{Mahjax},这是一个完全向量化的立直麻将环境,用JAX实现,以实现图形处理单元(GPU)上的大规模rollout并行化。我们还提供了一个高质量的可视化工具,以简化调试和与训练好的agent的交互。实验结果表明,在无赤规则和有赤规则下,Mahjax在八个NVIDIA A100 GPU上分别实现了高达 extbf{200万}和 extbf{100万steps per second}的吞吐量。此外,我们通过展示agent可以有效地训练以提高其相对于基线策略的排名,从而验证了该环境对强化学习的实用性。

🔬 方法详解

问题定义:现有麻将强化学习研究主要依赖于人类棋谱的监督学习,这种方法的泛化能力受到限制,无法探索新的策略。从零开始学习(tabula rasa)的算法,如AlphaZero,虽然具有更大的潜力,但麻将环境的复杂性(高维状态空间、随机性、不完全信息)使得训练非常耗时,需要大量的计算资源。

核心思路:Mahjax的核心思路是利用JAX框架的自动微分和GPU加速能力,构建一个高效的、完全向量化的麻将模拟环境。通过在GPU上进行大规模的rollout并行化,可以显著提高训练速度,使得从零开始训练麻将agent成为可能。同时,提供高质量的可视化工具,方便调试和分析agent的行为。

技术框架:Mahjax的整体框架包括以下几个主要模块:1) 麻将规则引擎:负责模拟麻将的游戏规则,包括摸牌、出牌、碰、杠、吃、和牌等。2) 状态表示:将麻将的状态(包括牌局信息、玩家手牌、弃牌堆等)编码成向量形式,作为agent的输入。3) 动作空间:定义agent可以采取的动作,例如出牌、吃、碰、杠、和牌等。4) JAX实现:使用JAX框架实现上述模块,并利用JAX的自动向量化和GPU加速能力。5) 可视化工具:提供图形界面,用于显示牌局状态、agent的决策过程等。

关键创新:Mahjax的关键创新在于其完全向量化的JAX实现,这使得它能够充分利用GPU的并行计算能力,实现极高的吞吐量。与传统的基于CPU的麻将模拟器相比,Mahjax的训练速度提高了几个数量级。此外,Mahjax还提供了一个高质量的可视化工具,方便调试和分析agent的行为。

关键设计:Mahjax的关键设计包括:1) 使用JAX的vmap函数实现自动向量化,使得可以同时模拟多个牌局。2) 使用GPU加速的随机数生成器,用于模拟摸牌和发牌的随机性。3) 使用高效的数据结构来表示麻将的状态,例如使用one-hot编码来表示牌的种类和数量。4) 提供灵活的API,方便用户自定义agent的策略和训练算法。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

Mahjax在八个NVIDIA A100 GPU上实现了极高的吞吐量,在无赤规则下达到200万steps per second,在有赤规则下达到100万steps per second。实验还表明,使用Mahjax可以有效地训练麻将agent,使其能够提高相对于基线策略的排名,验证了该环境在强化学习中的实用性。

🎯 应用场景

Mahjax为麻将强化学习研究提供了一个强大的平台,可用于开发更强大的麻将AI。其技术也可推广到其他复杂的多人博弈游戏,例如德州扑克、星际争霸等,甚至可以应用于现实世界的决策问题,例如资源分配、交通调度等。该研究有助于推动强化学习算法在复杂环境下的应用。

📄 摘要(原文)

Riichi Mahjong is a multi-player, imperfect-information game characterized by stochasticity and high-dimensional state spaces. These attributes present a unique combination of challenges that mirror complex real-world decision-making problems in reinforcement learning. While prior research has heavily relied on supervised learning from human play logs to pre-train the policy, algorithms capable of learning \textit{tabula rasa} (from scratch) offer greater potential for general applicability, as evidenced by the AlphaZero lineage. To facilitate such research, we introduce \textbf{Mahjax}, a fully vectorized Riichi Mahjong environment implemented in JAX to enable large-scale rollout parallelization on Graphics Processing Units (GPUs). We also provide a high-quality visualization tool to streamline debugging and interaction with trained agents. Experimental results demonstrate that Mahjax achieves throughputs of up to \textbf{2 million} and \textbf{1 million steps per second} on eight NVIDIA A100 GPUs under the no-red and red rules, respectively. Furthermore, we validate the environment's utility for reinforcement learning by showing that agents can be trained effectively to improve their rank against baseline policies.