Drama: Mamba-Enabled Model-Based Reinforcement Learning Is Sample and Parameter Efficient
作者: Wenlong Wang, Ivana Dusparic, Yucheng Shi, Ke Zhang, Vinny Cahill
分类: cs.LG, cs.AI, cs.RO
发布日期: 2024-10-11 (更新: 2025-05-16)
备注: Published as a conference paper at ICLR 2025
🔗 代码/项目: GITHUB
💡 一句话要点
Drama:基于Mamba的状态空间模型提升模型强化学习的样本效率和参数效率
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 模型强化学习 状态空间模型 Mamba 序列建模 Atari100k
📋 核心要点
- 传统RNN世界模型难以捕捉长期依赖,Transformer则面临自注意力机制带来的高昂计算成本。
- Drama利用Mamba状态空间模型,实现了线性复杂度的内存和计算,有效捕获长期依赖关系。
- Drama结合新颖的采样方法,缓解了早期训练阶段世界模型不准确带来的次优性,并在Atari100k上取得SOTA结果。
📝 摘要(中文)
基于模型的强化学习(RL)为大多数无模型RL算法中普遍存在的数据低效问题提供了一种解决方案。然而,学习一个鲁棒的世界模型通常需要复杂而深层的架构,这在计算上是昂贵的且训练具有挑战性。在世界模型中,序列模型在准确预测中起着关键作用,并且已经探索了各种架构,每种架构都有其自身的挑战。目前,基于循环神经网络(RNN)的世界模型在处理梯度消失和捕获长期依赖关系方面存在困难。另一方面,Transformer受到自注意力机制的二次方内存和计算复杂度的影响,其复杂度随序列长度n呈O(n^2)增长。为了应对这些挑战,我们提出了一种基于状态空间模型(SSM)的世界模型Drama,特别利用了Mamba,它实现了O(n)的内存和计算复杂度,同时有效地捕获了长期依赖关系,并能够使用更长的序列进行高效训练。我们还引入了一种新的采样方法,以减轻早期训练阶段中不正确的世界模型所导致的次优性。结合这些技术,Drama在Atari100k基准测试上获得了标准化的分数,该分数与其他最先进的(SOTA)基于模型的RL算法相比具有竞争力,并且仅使用了700万参数的世界模型。Drama可以在现成的硬件(如标准笔记本电脑)上访问和训练。我们的代码可在https://github.com/realwenlongwang/Drama.git上找到。
🔬 方法详解
问题定义:论文旨在解决基于模型的强化学习中,世界模型训练效率和模型容量之间的矛盾。现有方法,如基于RNN的模型难以捕捉长期依赖,而基于Transformer的模型则面临计算复杂度过高的问题,限制了其在长序列任务中的应用。
核心思路:论文的核心思路是利用Mamba状态空间模型作为世界模型,Mamba具有线性复杂度,能够高效地处理长序列数据,同时具备捕捉长期依赖关系的能力。此外,论文还提出了一种新的采样方法,以缓解早期训练阶段世界模型不准确带来的影响,从而提高整体训练效率和性能。
技术框架:Drama的整体框架包括以下几个主要模块:1) 环境交互模块:负责与环境进行交互,收集经验数据。2) 世界模型模块:使用Mamba作为核心,学习环境的动态模型。3) 策略优化模块:基于学习到的世界模型,使用强化学习算法优化策略。4) 采样模块:使用提出的新采样方法,从世界模型中采样数据,用于策略优化。整个流程是:环境交互产生数据 -> 世界模型学习 -> 采样 -> 策略优化 -> 更新策略,循环迭代。
关键创新:论文最重要的技术创新点在于将Mamba状态空间模型引入到基于模型的强化学习中,并将其作为世界模型的核心组件。与传统的RNN和Transformer相比,Mamba具有线性复杂度,能够高效地处理长序列数据,并且能够有效地捕捉长期依赖关系。此外,提出的新采样方法也是一个重要的创新点,它能够缓解早期训练阶段世界模型不准确带来的影响,从而提高整体训练效率和性能。
关键设计:Mamba的具体参数设置未知,论文可能使用了标准的Mamba架构。损失函数方面,世界模型的训练可能使用了预测误差作为损失函数,策略优化可能使用了常见的强化学习损失函数,如PPO或SAC。新采样方法的具体细节未知,但其目标是选择更有利于策略优化的样本。
🖼️ 关键图片
📊 实验亮点
Drama在Atari100k基准测试上取得了与SOTA模型强化学习算法相竞争的性能,同时仅使用了700万参数的世界模型,并且可以在标准笔记本电脑上进行训练。这表明Drama在参数效率和计算效率方面具有显著优势,使其更易于部署和应用。
🎯 应用场景
该研究成果可应用于各种需要长期规划和预测的强化学习任务,例如机器人控制、游戏AI、自动驾驶等。通过高效地学习环境动态模型,可以显著提高强化学习算法的样本效率和泛化能力,降低训练成本,并加速实际应用落地。未来,该方法有望扩展到更复杂的环境和任务中,例如多智能体系统和部分可观测环境。
📄 摘要(原文)
Model-based reinforcement learning (RL) offers a solution to the data inefficiency that plagues most model-free RL algorithms. However, learning a robust world model often requires complex and deep architectures, which are computationally expensive and challenging to train. Within the world model, sequence models play a critical role in accurate predictions, and various architectures have been explored, each with its own challenges. Currently, recurrent neural network (RNN)-based world models struggle with vanishing gradients and capturing long-term dependencies. Transformers, on the other hand, suffer from the quadratic memory and computational complexity of self-attention mechanisms, scaling as $O(n^2)$, where $n$ is the sequence length. To address these challenges, we propose a state space model (SSM)-based world model, Drama, specifically leveraging Mamba, that achieves $O(n)$ memory and computational complexity while effectively capturing long-term dependencies and enabling efficient training with longer sequences. We also introduce a novel sampling method to mitigate the suboptimality caused by an incorrect world model in the early training stages. Combining these techniques, Drama achieves a normalised score on the Atari100k benchmark that is competitive with other state-of-the-art (SOTA) model-based RL algorithms, using only a 7 million-parameter world model. Drama is accessible and trainable on off-the-shelf hardware, such as a standard laptop. Our code is available at https://github.com/realwenlongwang/Drama.git.